from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime
from itertools import product
from functools import partial
from typing import Dict, Callable
import time
import logging

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import tqdm

import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit, hessian, lax
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
from jax.nn import relu
from jax.config import config
from jax.flatten_util import ravel_pytree
import optax
import jaxopt

import torch
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.acquisition import qExpectedImprovement, qUpperConfidenceBound, qSimpleRegret
from botorch.optim.initializers import gen_batch_initial_conditions, initialize_q_batch_nonneg
from botorch.generation import gen_candidates_torch, get_best_candidates
from botorch.sampling.stochastic_samplers import StochasticSampler

from ..icbc_patch import generate_residue
from ..models.pinn_ensemble import PINNEnsemble
from ..utils import sample_from_uniform
from ..utils.jax_utils import flatten, vmap_mjp, vmap_jmp, jacobian_outer_product
from ..utils.vmap_chunked import vmap_chunked
# from .ed_loop import ExperimentalDesign
from .criterion_based import CriterionBasedAbstractMethod
from .eig_estimators.losses import generate_loss

# Configure the *root* logger (``logging.getLogger()`` with no name) to emit
# INFO and above. This is a process-wide setting applied at import time; it is
# not a logger scoped to this module.
_root_logger = logging.getLogger()
_root_logger.setLevel(logging.INFO)


class PINNTolerableInverseParams(CriterionBasedAbstractMethod):
    """Criterion-based observation design for PINN inverse problems.

    For every member of the forward PINN ensemble, the criterion estimates how
    the network weights would move if the inverse parameters were perturbed,
    using a single step ``-H^{-1} g`` on the PDE + IC/BC loss. It then measures
    how sensitive the (noisy) observation readings are to that perturbation.

    An observation design is scored by the average log-determinant of that
    sensitivity matrix over the ensemble. The matrix is either the Gauss-Newton
    approximation or the exact Hessian of the anchor misfit.
    """

    def __init__(self, simulator_xs, pde, pde_domain, exp_design_fn, obs_design_fn,
                 inv_embedding, inv_param_in_domain, exp_in_domain, obs_in_domain,
                 inv_input_dim, exp_input_dim, obs_input_dim, obs_reading_count,
                 x_input_dim, y_output_dim,
                 ensemble_size: int = 100, pinn_ensemble_args: Dict = None,
                 pinn_share_init: bool = False, pinn_init_meta_rounds: int = 0, pinn_init_meta_steps: int = 1000, pinn_meta_eps: float = 0.1,
                 ensemble_steps: int = 100000, acq_fn: str = 'ucb',
                 approx_param_method: str = 'self', chunk_size: int = 16, reg: float = 1e-6,
                 pde_colloc_sample_num: int = 1000, icbc_colloc_sample_num: int = 100,
                 approx_H_inv: bool = True, approx_H_nn: bool = True, pinv_jac: bool = False, use_jmp: bool = False,
                 consider_embedding: bool = False, max_inv_embedding_dim: int = 100,
                 exp_setup_rounds: int = 10, obs_setup_rounds: int = 10, obs_search_time_limit: float = 3600., noise_std: float = 1e-3,
                 obs_optim_use_lbfgs: bool = False, obs_optim_gd_params: Dict = None,
                 obs_optim_grad_clip: float = None, obs_optim_grad_jitter: float = None, obs_optim_grad_zero_rate: float = None,
                 min_obs_rounds: int = 3, seed: int = 0):
        """Store criterion settings and initialise the PINN-based base class.

        Arguments not listed below are forwarded unchanged to
        ``CriterionBasedAbstractMethod.__init__``. This class always passes
        ``use_pinns=True``, ``obs_optim_with_gd=True`` and ``do_jit=True``.

        Args:
            pinn_ensemble_args: Extra ensemble arguments; ``None`` means ``{}``.
            obs_optim_gd_params: Gradient-descent settings for the observation
                search. ``None`` means
                ``dict(stepsize=1e-2, maxiter=100, acceleration=True)``.
                A fresh dict is created per instance, so instances never share
                (and cannot mutate) a common default.
            approx_param_method: How the tracked network weights are obtained.
                Only ``'self'`` (the forward ensemble's own weights) is supported.
            chunk_size: Chunk size for ``vmap_chunked`` over ensemble members.
            reg: Diagonal jitter added to the network-parameter Hessian and to
                the criterion matrix before ``slogdet``.
            pde_colloc_sample_num: Maximum number of PDE collocation points
                subsampled for the residual terms.
            icbc_colloc_sample_num: Maximum number of points kept per IC/BC term.
            approx_H_inv: If True, the criterion matrix is ``2 J^T J`` of the
                anchor residues; otherwise it is the exact Hessian of the anchor
                loss.
            approx_H_nn: If True, the network-parameter Hessian is the
                Gauss-Newton ``2 J^T J`` of the PDE + IC/BC residues; otherwise
                it is the exact Hessian.
            pinv_jac: Invert via the pseudo-inverse of the residue Jacobian
                instead of the Hessian. Requires ``approx_H_nn=True``.
            use_jmp: Stored and threaded through the closures, but not used by
                the current parameter-update computation.
            consider_embedding: Measure sensitivity with respect to
                ``inv_embedding(inv)`` rather than the raw inverse parameters.
            max_inv_embedding_dim: If the embedding is larger than this, a random
                subset of this many coordinates is kept.
        """
        if pinn_ensemble_args is None:
            pinn_ensemble_args = dict()
        if obs_optim_gd_params is None:
            obs_optim_gd_params = dict(stepsize=1e-2, maxiter=100, acceleration=True)

        super().__init__(
            simulator_xs=simulator_xs,
            pde=pde,
            pde_domain=pde_domain,
            exp_design_fn=exp_design_fn,
            obs_design_fn=obs_design_fn,
            inv_embedding=inv_embedding,
            inv_param_in_domain=inv_param_in_domain,
            exp_in_domain=exp_in_domain,
            obs_in_domain=obs_in_domain,
            inv_input_dim=inv_input_dim,
            exp_input_dim=exp_input_dim,
            obs_input_dim=obs_input_dim,
            obs_reading_count=obs_reading_count,
            x_input_dim=x_input_dim,
            y_output_dim=y_output_dim,
            use_pinns=True,
            ensemble_size=ensemble_size,
            ensemble_steps=ensemble_steps,
            pinn_ensemble_args=pinn_ensemble_args,
            pinn_share_init=pinn_share_init,
            pinn_init_meta_rounds=pinn_init_meta_rounds,
            pinn_init_meta_steps=pinn_init_meta_steps,
            pinn_meta_eps=pinn_meta_eps,
            acq_fn=acq_fn,
            exp_setup_rounds=exp_setup_rounds,
            obs_setup_rounds=obs_setup_rounds,
            obs_search_time_limit=obs_search_time_limit,
            noise_std=noise_std,
            obs_optim_with_gd=True,
            obs_optim_use_lbfgs=obs_optim_use_lbfgs,
            do_jit=True,
            obs_optim_gd_params=obs_optim_gd_params,
            obs_optim_grad_clip=obs_optim_grad_clip,
            obs_optim_grad_jitter=obs_optim_grad_jitter,
            obs_optim_grad_zero_rate=obs_optim_grad_zero_rate,
            min_obs_rounds=min_obs_rounds,
            seed=seed,
        )

        self.approx_param_method = approx_param_method
        self.pde_colloc_sample_num = pde_colloc_sample_num
        self.icbc_colloc_sample_num = icbc_colloc_sample_num
        self.approx_H_inv = approx_H_inv
        self.approx_H_nn = approx_H_nn
        self.pinv_jac = pinv_jac
        self.use_jmp = use_jmp
        self.consider_embedding = consider_embedding
        self.max_inv_embedding_dim = max_inv_embedding_dim
        self.chunk_size = chunk_size
        self.reg = reg

    def _generate_criterion_inner(self, exp_design, true_inv_prior_samples):
        """Build the observation-design criterion for a fixed experimental design.

        Args:
            exp_design: Experimental design forwarded to ``self.pde`` whenever
                the PDE residual is evaluated.
            true_inv_prior_samples: Not used to build the criterion; only
                returned inside ``helper_fns``.

        Returns:
            ``(criterion, helper_fns)``. ``criterion(obs_params, rng)`` returns
            ``(score, {'Hs': Hs})``, where ``Hs`` holds one criterion matrix per
            ensemble member. ``helper_fns`` maps names to the intermediate
            closures for inspection and debugging.

        Raises:
            ValueError: If ``self.approx_param_method`` is not ``'self'``.
        """
        if self.approx_param_method != 'self':
            raise ValueError(
                f"Unsupported approx_param_method {self.approx_param_method!r}; "
                f"only 'self' is implemented."
            )

        # Network weights of every forward-ensemble member (one per inverse-parameter guess).
        tracked_params = [self.forward_ens.get_param(i)['net'] for i in range(len(self.forward_ens.inv_params))]
        num_members = len(tracked_params)

        def modified_apply(p, xs):
            return self.forward_ens.net.apply(p, xs)

        tracked_params_dim = flatten(tracked_params[0])[0].shape[0]
        # Stack member weights along a new leading axis (cast to JAX's default
        # float dtype) so they can be vmapped over.
        tracked_params_stacked = jax.tree_util.tree_map(
            lambda *x: jnp.array(jnp.stack(x), dtype=jnp.zeros(1).dtype), *tracked_params
        )
        inv_params = self.forward_ens.inv_params

        dinv_zero = jnp.zeros_like(inv_params[0])
        # Subsample PDE collocation points without replacement to bound the residual size.
        # NOTE(review): draws from NumPy's global RNG rather than self.get_rng();
        # confirm this subsample is seeded as intended.
        n_colloc = self.forward_ens.pde_collocation_pts.shape[0]
        idxs_sample = jnp.array(np.random.choice(
            n_colloc,
            size=min(n_colloc, self.pde_colloc_sample_num),
            replace=False,
        ))
        xs_colloc = self.forward_ens.pde_collocation_pts[idxs_sample]

        if self.consider_embedding:

            inv_embedding = jax.vmap(self.inv_embedding)(inv_params)
            inv_embedding_dim = inv_embedding[0].shape[0]

            if inv_embedding_dim > self.max_inv_embedding_dim:
                # Keep a random subset of embedding coordinates so the criterion matrix stays small.
                idxs_condense = jax.random.choice(
                    self.get_rng(), inv_embedding_dim, shape=(self.max_inv_embedding_dim,), replace=False
                )
                new_inv_embedding_fn = jax.jit(lambda inv: self.inv_embedding(inv)[idxs_condense])
                inv_embedding = jax.vmap(new_inv_embedding_fn)(inv_params)
                inv_embedding_dim = self.max_inv_embedding_dim
            else:
                new_inv_embedding_fn = self.inv_embedding

            demb_zero = jnp.zeros_like(inv_embedding[0])

            @jax.jit
            def estimate_dinv_learnable(demb, emb0, inv0):
                """Map an embedding perturbation to an inverse-parameter perturbation.

                Uses the pseudo-inverse of the embedding Jacobian at ``inv0``.
                ``emb0`` is accepted for interface symmetry but unused.
                """
                J = jax.jacobian(new_inv_embedding_fn)(inv0)
                return jnp.linalg.pinv(J) @ demb

        else:

            new_inv_embedding_fn = None
            estimate_dinv_learnable = None

        def constraint_indiv(beta, params):
            """Mean squared PDE residual (first output component) at the sampled collocation points."""
            f_ = lambda xs: modified_apply(params, xs)
            return jnp.mean(self.pde(xs_colloc, (f_(xs_colloc), f_), beta, exp_design)[0][:, 0] ** 2)

        bc_fns = [fn for (fn, _) in self.forward_ens.exp_design_fn]
        bc_data = [xs[:min(xs.shape[0], self.icbc_colloc_sample_num)] for (_, xs) in self.forward_ens.exp_design_fn]

        def icbc_loss(params, b):
            """Sum over IC/BC terms of each term's mean squared residual."""
            total = 0.
            for bc_fn, bc_xs in zip(bc_fns, bc_data):
                total += jnp.mean(bc_fn(params, modified_apply, self.forward_ens.exp_params, b, bc_xs) ** 2)
            return total

        def _readings(p, obs_params):
            """Return the observation readings of network ``p`` as a 1-D array.

            Flattening both predicted and reference readings guarantees an
            elementwise difference. Mixing a flattened and an unflattened
            ``(N, 1)`` array would silently broadcast to ``(N, N)``.
            """
            return self.obs_design_fn(lambda xs: modified_apply(p, xs), obs_params).reshape(-1)

        def _anchor_residue(nn_params, obs_params, anchor_params, rng):
            """Scaled misfit between readings of ``nn_params`` and noisy readings of ``anchor_params``."""
            ys = _readings(anchor_params, obs_params)
            ys_noisy = ys + self.noise_std * jax.random.normal(key=rng, shape=ys.shape)
            return (_readings(nn_params, obs_params) - ys_noisy) / (self.obs_reading_count ** 0.5)

        def _pde_residue(nn_params, beta):
            """Flattened PDE residual scaled so its squared norm equals the mean squared residual."""
            f_ = lambda xs: modified_apply(nn_params, xs)
            return self.pde(xs_colloc, (f_(xs_colloc), f_), beta, exp_design)[0][:, 0].reshape(-1) / (xs_colloc.shape[0] ** 0.5)

        def _bc_residues(nn_params, beta):
            """List of flattened IC/BC residuals, each scaled by 1/sqrt(#points)."""
            return [
                bc_fn(nn_params, modified_apply, self.forward_ens.exp_params, beta, bc_xs).reshape(-1) / (bc_xs.shape[0] ** 0.5)
                for bc_fn, bc_xs in zip(bc_fns, bc_data)
            ]

        @jax.jit
        def anchor_loss(p, b, current_param, obs_params, rng):
            """Mean squared misfit between readings of ``p`` and noisy readings of ``current_param``."""
            ys_pred = _readings(p, obs_params)
            ys = _readings(current_param, obs_params)
            ys_noisy = ys + self.noise_std * jax.random.normal(key=rng, shape=ys.shape)
            return jnp.mean((ys_pred - ys_noisy) ** 2)

        @jax.jit
        def forward_loss(p, b, current_param, obs_params):
            """PDE + IC/BC loss of network ``p`` under inverse parameters ``b``."""
            return constraint_indiv(b, p) + icbc_loss(p, b)

        @jax.jit
        def loss(p, b, current_param, obs_params, rng):
            """Forward loss plus anchor (observation) loss."""
            return constraint_indiv(b, p) + icbc_loss(p, b) + anchor_loss(p, b, current_param, obs_params, rng)

        @jax.jit
        def separate_residues(nn_params, obs_params, beta, anchor_params, rng):
            """All scaled residues, keyed by 'anc' (observations), 'pde' and 'bcs'."""
            return {
                'anc': _anchor_residue(nn_params, obs_params, anchor_params, rng),
                'pde': _pde_residue(nn_params, beta),
                'bcs': _bc_residues(nn_params, beta),
            }

        @jax.jit
        def anchor_residues(nn_params, obs_params, beta, anchor_params, rng):
            """Only the scaled observation residue, keyed by 'anc'."""
            return {'anc': _anchor_residue(nn_params, obs_params, anchor_params, rng)}

        @jax.jit
        def colloc_residues(nn_params, obs_params, beta, anchor_params):
            """Only the scaled PDE and IC/BC residues, keyed by 'pde' and 'bcs'."""
            return {
                'pde': _pde_residue(nn_params, beta),
                'bcs': _bc_residues(nn_params, beta),
            }

        @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H'])
        def get_Hinv(beta0, current_param, obs_params, approx_H=False, pinv_jac=False):
            """(Pseudo-)inverse of the PDE + IC/BC loss Hessian w.r.t. the flattened network weights.

            Modes:
                ``approx_H=True``: Gauss-Newton ``H = 2 J^T J``, where ``J`` is
                    the Jacobian of the scaled collocation residues.
                ``approx_H=False``: the exact Hessian.
                ``pinv_jac=True``: returns ``0.5 * J^+ (J^+)^T`` directly,
                    without ``reg``.
                Otherwise: ``reg * I`` is added and ``H`` is pseudo-inverted as
                    a Hermitian matrix.

            Raises:
                ValueError: If ``pinv_jac`` is requested without ``approx_H``;
                    no Jacobian is formed in the exact-Hessian mode.
            """
            if pinv_jac and not approx_H:
                raise ValueError(
                    "pinv_jac=True requires approx_H=True: the Jacobian pseudo-inverse "
                    "is only defined for the Gauss-Newton Hessian."
                )

            pflat, unflatten = flatten(current_param)

            if approx_H:
                res_flat = lambda p: jnp.concatenate(
                    jax.tree_util.tree_flatten(colloc_residues(unflatten(p), obs_params, beta0, current_param))[0]
                )
                J = jax.jacfwd(res_flat)(pflat)
                if pinv_jac:
                    # H = 2 J^T J  =>  H^+ = 0.5 * J^+ (J^+)^T
                    J_pinv = jnp.linalg.pinv(J)
                    return 0.5 * J_pinv @ J_pinv.T
                H = 2. * J.T @ J
            else:
                orig_loss = lambda p: constraint_indiv(beta0, p) + icbc_loss(p, beta0)
                H = jax.hessian(lambda t: orig_loss(unflatten(t)))(pflat)

            H = H + self.reg * jnp.eye(tracked_params_dim)
            return jnp.linalg.pinv(H, hermitian=True)

        @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H', 'use_jmp'])
        def get_param_change(dinv, beta0, current_param, obs_params, approx_H=False, pinv_jac=False, use_jmp=True):
            """Return ``current_param`` shifted by ``-H^{-1} g``.

            ``g`` is the gradient of the change in forward loss when the inverse
            parameters move from ``beta0`` to ``beta0 + dinv``. ``use_jmp`` is
            accepted but currently unused.
            """
            pflat, unflatten = flatten(current_param)
            beta_shifted = beta0 + dinv
            perturbed_loss = lambda p: (
                forward_loss(p, beta_shifted, current_param, obs_params)
                - forward_loss(p, beta0, current_param, obs_params)
            )

            H_inv = get_Hinv(beta0, current_param, obs_params, approx_H, pinv_jac)

            g = jax.jacobian(lambda t: perturbed_loss(unflatten(t)))(pflat)
            res_params_change = unflatten(- H_inv @ g)

            return jax.tree_util.tree_map(lambda x, y: x + y, current_param, res_params_change)

        @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H', 'use_jmp'])
        def perturbed_loss_fn_learnable(dinv, obs_params, beta0, current_param, rng, approx_H=False, pinv_jac=False, use_jmp=True):
            """Anchor loss of the network re-fitted (one step) to inverse parameters ``beta0 + dinv``."""
            new_res_param = get_param_change(
                dinv=dinv,
                beta0=beta0,
                current_param=current_param,
                obs_params=obs_params,
                approx_H=approx_H,
                pinv_jac=pinv_jac,
                use_jmp=use_jmp,
            )
            new_beta = beta0 + dinv
            return anchor_loss(p=new_res_param, b=new_beta, current_param=current_param, obs_params=obs_params, rng=rng)

        @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H', 'use_jmp'])
        def perturbed_residue_fn_learnable(dinv, obs_params, beta0, current_param, rng, approx_H=False, pinv_jac=False, use_jmp=True):
            """Anchor residues of the network re-fitted (one step) to inverse parameters ``beta0 + dinv``."""
            new_res_param = get_param_change(
                dinv=dinv,
                beta0=beta0,
                current_param=current_param,
                obs_params=obs_params,
                approx_H=approx_H,
                pinv_jac=pinv_jac,
                use_jmp=use_jmp,
            )
            new_beta = beta0 + dinv
            return anchor_residues(nn_params=new_res_param, obs_params=obs_params, beta=new_beta, anchor_params=current_param, rng=rng)

        if self.consider_embedding:

            @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H', 'use_jmp'])
            def perturbed_loss_fn(demb, obs_params, emb0, inv0, current_param, rng,
                                  approx_H=False, pinv_jac=False, use_jmp=True):
                """Like ``perturbed_loss_fn_learnable`` but parameterised by an embedding perturbation."""
                dinv = estimate_dinv_learnable(demb=demb, emb0=emb0, inv0=inv0)
                return perturbed_loss_fn_learnable(dinv=dinv, obs_params=obs_params, beta0=inv0, current_param=current_param, rng=rng,
                                                   approx_H=approx_H, pinv_jac=pinv_jac, use_jmp=use_jmp)

            @partial(jax.jit, static_argnames=['pinv_jac', 'approx_H', 'use_jmp'])
            def perturbed_residue_fn(demb, obs_params, emb0, inv0, current_param, rng,
                                     approx_H=False, pinv_jac=False, use_jmp=True):
                """Like ``perturbed_residue_fn_learnable`` but parameterised by an embedding perturbation."""
                dinv = estimate_dinv_learnable(demb=demb, emb0=emb0, inv0=inv0)
                return perturbed_residue_fn_learnable(dinv=dinv, obs_params=obs_params, beta0=inv0, current_param=current_param, rng=rng,
                                                      approx_H=approx_H, pinv_jac=pinv_jac, use_jmp=use_jmp)

            @partial(jax.jit, static_argnames=['approx_H_inv', 'pinv_jac', 'approx_H_nn', 'use_jmp'])
            def hessian_approx(obs_params, emb0, inv0, current_param, rng, approx_H_inv=False, approx_H_nn=True, pinv_jac=False, use_jmp=True):
                """Criterion matrix for one member: ``2 J^T J`` of anchor residues, or the exact Hessian, at ``demb = 0``."""
                if approx_H_inv:
                    J = jax.jacrev(perturbed_residue_fn)(demb_zero, obs_params, emb0, inv0, current_param, rng=rng,
                                                         approx_H=approx_H_nn, pinv_jac=pinv_jac, use_jmp=use_jmp)
                    return 2. * jax.tree_util.tree_reduce(jnp.add, jax.tree_util.tree_map(lambda x: x.T @ x, J))
                return jax.hessian(perturbed_loss_fn)(demb_zero, obs_params, emb0, inv0, current_param, rng=rng,
                                                      approx_H=approx_H_nn, pinv_jac=pinv_jac, use_jmp=use_jmp)

            @partial(jax.jit, static_argnames=['approx_H_inv', 'pinv_jac', 'approx_H_nn', 'use_jmp'])
            def hessian_approx_batched(obs_params, rng, approx_H_inv=False, approx_H_nn=True, pinv_jac=False, use_jmp=True):
                """``hessian_approx`` for every ensemble member, each with its own RNG key."""
                rng_split = jax.random.split(rng, num=num_members)
                return vmap_chunked(
                    hessian_approx,
                    in_axes=(None, 0, 0, 0, 0, None, None, None, None),
                    chunk_size=self.chunk_size,
                )(obs_params, inv_embedding, inv_params, tracked_params_stacked, rng_split, approx_H_inv, approx_H_nn, pinv_jac, use_jmp)

            crit_dim = inv_embedding_dim

        else:

            perturbed_loss_fn = perturbed_loss_fn_learnable
            perturbed_residue_fn = perturbed_residue_fn_learnable

            @partial(jax.jit, static_argnames=['approx_H_inv', 'pinv_jac', 'approx_H_nn', 'use_jmp'])
            def hessian_approx(obs_params, mean_beta, current_param, rng, approx_H_inv=False, approx_H_nn=True, pinv_jac=False, use_jmp=True):
                """Criterion matrix for one member: ``2 J^T J`` of anchor residues, or the exact Hessian, at ``dinv = 0``."""
                if approx_H_inv:
                    J = jax.jacrev(perturbed_residue_fn)(dinv_zero, obs_params, mean_beta, current_param, rng, approx_H_nn, pinv_jac, use_jmp)
                    return 2. * jax.tree_util.tree_reduce(jnp.add, jax.tree_util.tree_map(lambda x: x.T @ x, J))
                return jax.hessian(perturbed_loss_fn)(dinv_zero, obs_params, mean_beta, current_param, rng, approx_H_nn, pinv_jac, use_jmp)

            @partial(jax.jit, static_argnames=['approx_H_inv', 'pinv_jac', 'approx_H_nn', 'use_jmp'])
            def hessian_approx_batched(obs_params, rng, approx_H_inv=False, approx_H_nn=True, pinv_jac=False, use_jmp=True):
                """``hessian_approx`` for every ensemble member, each with its own RNG key."""
                rng_split = jax.random.split(rng, num=num_members)
                return vmap_chunked(
                    hessian_approx,
                    in_axes=(None, 0, 0, 0, None, None, None, None),
                    chunk_size=self.chunk_size,
                )(obs_params, inv_params, tracked_params_stacked, rng_split, approx_H_inv, approx_H_nn, pinv_jac, use_jmp)

            crit_dim = self.inv_input_dim

        @jax.jit
        def criterion(obs_params, rng=jax.random.PRNGKey(0)):
            """Score an observation design by the mean regularised log-determinant of the member matrices."""
            Hs = hessian_approx_batched(
                obs_params,
                rng,
                approx_H_inv=self.approx_H_inv,
                approx_H_nn=self.approx_H_nn,
                pinv_jac=self.pinv_jac,
                use_jmp=self.use_jmp,
            )
            sign_H, logdet_H = jax.vmap(lambda H: jnp.linalg.slogdet(H + self.reg * jnp.eye(crit_dim)))(Hs)
            # Members with a positive determinant contribute their log-det.
            # Negative determinants contribute 0 (but still count in the mean).
            # A zero determinant yields 0 * -inf = NaN, which nanmean drops.
            score = jnp.nanmean(jax.nn.relu(sign_H) * logdet_H)
            return score, dict(Hs=Hs)

        helper_fns = {
            'tracked_params': tracked_params,
            'true_inv_prior_samples': true_inv_prior_samples,
            'modified_apply': modified_apply,
            'new_inv_embedding_fn': new_inv_embedding_fn,
            'estimate_dinv_learnable': estimate_dinv_learnable,
            'pde_loss': constraint_indiv,
            'icbc_loss': icbc_loss,
            'anchor_loss': anchor_loss,
            'total_loss': loss,
            'H_inv': get_Hinv,
            'separate_residues': separate_residues,
            'colloc_residues': colloc_residues,
            'get_param_change': get_param_change,
            'perturbed_loss_fn': perturbed_loss_fn,
            'perturbed_residue_fn': perturbed_residue_fn,
            'perturbed_loss_fn_learnable': perturbed_loss_fn_learnable,
            'perturbed_residue_fn_learnable': perturbed_residue_fn_learnable,
            'hessian_approx': hessian_approx,
            'hessian_approx_batched': hessian_approx_batched,
            'criterion': criterion,
        }
        return criterion, helper_fns
