import jax
from jax import vmap, lax, numpy as jnp
from jax.flatten_util import ravel_pytree as ravel
from .stats.hmm import *
from .utils.misc import *
from src.stats import BackwardSmoother
from src.variational import NeuralBackwardSmoother
vmap_ravel = jax.vmap(lambda x: ravel_pytree(x)[0])
import blackjax
from jax import dtypes

def value_and_jac(fun, argnums=0, has_aux=False):
    """Return a function computing ``fun``'s value together with its Jacobian.

    Mirrors the output layout of ``jax.value_and_grad`` but uses
    ``jax.jacobian`` so that ``fun`` may have non-scalar outputs.

    Args:
        fun: function to differentiate. When ``has_aux`` is true it must
            return a pair ``(output, aux)``; only ``output`` is differentiated.
        argnums: positional argument index (or tuple of indices) to
            differentiate with respect to.
        has_aux: whether ``fun`` returns auxiliary data.

    Returns:
        A function taking the same positional arguments as ``fun`` and
        returning ``(fun(*args), jac)``, i.e. ``((output, aux), jac)`` when
        ``has_aux`` is true.
    """
    # The primal value is routed through ``has_aux`` so ``fun`` is traced /
    # evaluated once per call instead of twice (once directly, once inside
    # ``jax.jacobian``).
    if has_aux:
        def _with_value(*args):
            out, aux = fun(*args)
            return out, (out, aux)
    else:
        def _with_value(*args):
            out = fun(*args)
            return out, out

    jac_fn = jax.jacobian(_with_value, argnums=argnums, has_aux=True)

    def _value_and_jac(*args):
        jac, value = jac_fn(*args)
        return value, jac

    return _value_and_jac


class OnlineVariationalAdditiveSmoothing:
    """Base class to compute expectations of additive state functionals via our online algorithm.

    The algorithm itself is supplied as plain functions (carry creation,
    time-0 step, time-t step, pre- and post-processing). This class binds
    them to the model ``p``, the variational smoother ``q``, the additive
    functional ``additive_functional(p, q)`` and the extra keyword
    ``options``, which are forwarded to every one of those functions.
    """

    def __init__(self,
                 p: HMM,
                 q: BackwardSmoother,
                 additive_functional,
                 init_carry_fn,
                 init_fn,
                 update_fn,
                 preprocess_fn,
                 postprocess_fn,
                 **options):
        """
        Args:
            p: the model.
            q: the variational backward smoother.
            additive_functional: callable ``(p, q) -> functional``. The
                result's ``init`` / ``update`` attributes are passed as ``h``
                to ``init_fn`` / ``update_fn``; the whole object is passed as
                ``h`` to ``init_carry_fn``.
            init_carry_fn: ``(params, params_model, **kw) -> carry``.
            init_fn: ``(carry, input, **kw) -> (carry, output)`` for t == 0.
            update_fn: ``(carry, input, **kw) -> (carry, output)`` for t > 0.
            preprocess_fn: ``(obs_seq, **kw) -> strided observations``.
            postprocess_fn: ``(carry, **kw) -> results``.
            **options: forwarded as keyword arguments to all of the above.
        """
        self.p = p
        self.q = q

        self.additive_functional: AdditiveFunctional = additive_functional(p, q)

        self.options = options

        self._preprocess_fn = preprocess_fn
        self._init_carry_fn = init_carry_fn
        self._init_fn = init_fn
        self._update_fn = update_fn
        self._postprocess_fn = postprocess_fn

    def _init(self, carry, input):
        """Run ``init_fn`` with the functional's ``init`` as ``h``."""
        return self._init_fn(carry, input,
                             p=self.p,
                             q=self.q,
                             h=self.additive_functional.init,
                             **self.options)

    def _update(self, carry, input):
        """Run ``update_fn`` with the functional's ``update`` as ``h``."""
        return self._update_fn(carry, input,
                               p=self.p,
                               q=self.q,
                               h=self.additive_functional.update,
                               **self.options)

    def init_carry(self, params, params_model=None):
        """Build the initial (time -1) carry.

        ``params_model`` defaults to ``None`` so carry functions that do not
        need model parameters can be called with ``params`` alone.
        """
        return self._init_carry_fn(params,
                                   params_model,
                                   p=self.p,
                                   q=self.q,
                                   h=self.additive_functional,
                                   **self.options)

    def step(self, carry, input):
        """One step of the recursion: ``init_fn`` at ``t == 0``, ``update_fn`` otherwise.

        ``lax.cond`` traces both branches, so both must accept this carry and
        return the same pytree structure.
        """
        carry, output = lax.cond(input['t'] != 0,
                                 self._update,
                                 self._init,
                                 carry, input)

        return carry, output

    def preprocess(self, obs_seq, **kwargs):
        """Apply ``preprocess_fn`` to a raw observation sequence."""
        return self._preprocess_fn(obs_seq, **kwargs, **self.options)

    def batch_compute(self, key, strided_obs_seq, theta, phi):
        """Compute the expectation on a sequence of fixed length, processing all data at once.

        Args:
            key: PRNG key, split into one key per time step.
            strided_obs_seq: pre-processed observations; the leading axis
                indexes the time steps ``0..T``.
            theta: unformatted model parameters. They are passed to each step
                as ``input['theta']`` (read by the init/update functions) and
                also placed in the carry under ``'theta'``.
            phi: unformatted variational parameters.

        Returns:
            ``(final_carry, stacked_outputs)`` as produced by ``lax.scan``.
        """
        T = len(strided_obs_seq) - 1  # T + 1 observations

        keys = jax.random.split(key, T + 1)  # T + 1 keys
        timesteps = jnp.arange(0, T + 1)  # [0:T]

        def _step(carry, x):
            t, key_t, strided_ys = x
            input_t = {'t': t,
                       'key': key_t,
                       'T': T,
                       'ys_bptt': strided_ys,
                       'phi': phi,
                       'theta': theta}
            # Copy instead of mutating the dict handed over by lax.scan.
            carry = {**carry, 'theta': theta}
            return self.step(carry, input_t)

        # The carry function needs the model parameters too (for the shapes of
        # model-parameter statistics), so pass theta alongside phi.
        carry_m1 = self.init_carry(phi, theta)

        return lax.scan(_step,
                        init=carry_m1,
                        xs=(timesteps, keys, strided_obs_seq))

    def postprocess(self, carry, **kwargs):
        """Apply ``postprocess_fn`` to a final carry."""
        return self._postprocess_fn(carry, **kwargs, **self.options)



def init_carry(unformatted_params, params_model, **kwargs):
    """Allocate a placeholder carry for the generic online smoothing recursion.

    ``unformatted_params`` and ``params_model`` are accepted only to match
    the carry-function interface; they are not used.

    Args:
        **kwargs: must provide ``num_samples``, ``h`` (with ``out_shape``),
            ``p`` (with ``state_dim``) and ``q`` (with ``empty_state()``).

    Returns:
        Dict with ``state``, ``log_q_x`` of shape ``(num_samples,)``, ``x`` of
        shape ``(num_samples, state_dim)`` and ``stats['tau']`` of shape
        ``(num_samples, *out_shape)``.
    """
    n = kwargs['num_samples']
    tau_shape = (n, *kwargs['h'].out_shape)
    x_shape = (n, kwargs['p'].state_dim)
    placeholder_state = kwargs['q'].empty_state()

    return dict(state=placeholder_state,
                log_q_x=jnp.empty((n,)),
                x=jnp.empty(x_shape),
                stats=dict(tau=jnp.empty(tau_shape)))





def preprocess_for_bptt(obs_seq, bptt_depth, **kwargs):
    """Left-pad the observations and cut them into ``bptt_depth``-wide strides.

    ``bptt_depth - 1`` placeholder rows are prepended. JAX's ``jnp.empty``
    is zero-filled. Presumably this gives every time step a full window;
    the striding itself is delegated to ``tree_get_strides``.

    Assumes ``obs_seq`` is 2-D ``(time, obs_dim)``, because the padding is
    built from ``obs_seq.shape[1]`` and concatenated along axis 0.
    """
    n_pad = bptt_depth - 1
    padding = jnp.empty((n_pad, obs_seq.shape[1]))
    padded_ys = jnp.concatenate([padding, obs_seq])
    return tree_get_strides(bptt_depth, padded_ys)
    

    
def init_carry_elbo_score_gradients(unformatted_params, params_model, **kwargs):
    """Create the zero-initialised carry for the online ELBO score-gradient recursion.

    Args:
        unformatted_params: variational parameters ``phi``; only their pytree
            structure and leaf shapes are used.
        params_model: model parameters ``theta``; only their pytree structure
            and leaf shapes are used (integer leaves allowed).
        **kwargs: must provide ``num_samples``, ``p`` (with ``state_dim``),
            ``q`` (with ``empty_state()``) and ``h`` (with ``out_shape``).

    Returns:
        Dict with the latent states ``base_s`` / ``s``, particles ``x``,
        ``log_q``, statistics ``H`` / ``F`` / ``G`` / ``B`` and
        ``grad_log_q``, which is the very same pytree object as ``F``.
    """
    n_particles = kwargs['num_samples']
    x_dim = kwargs['p'].state_dim
    placeholder_state = kwargs['q'].empty_state()
    h_shape = kwargs['h'].out_shape

    zeros_H = jnp.zeros((n_particles, *h_shape))

    def _zero_jacobian(params, **jac_kwargs):
        # Jacobian of a constant function: an all-zero pytree with the
        # structure of ``params`` and ``zeros_H.shape`` prepended to each leaf.
        return jax.jacrev(lambda _: zeros_H, **jac_kwargs)(params)

    zeros_F = _zero_jacobian(unformatted_params)

    # s are the latent states
    return {'base_s': placeholder_state,
            's': placeholder_state,
            'x': jnp.zeros((n_particles, x_dim)),
            'log_q': jnp.zeros((n_particles,)),
            'stats': {'H': zeros_H,
                      'F': zeros_F,
                      'G': _zero_jacobian(params_model, allow_int=True),
                      'B': _zero_jacobian(params_model, allow_int=True)},
            'grad_log_q': zeros_F}

def init_elbo_score_gradients(carry_m1, input_0, **kwargs):
    """Initialise the online ELBO score-gradient recursion at time 0.

    Each of ``num_samples`` particles draws ``x_0`` from the time-0
    filtering distribution of ``q`` and records ``log q(x_0)`` and its
    gradient w.r.t. the variational parameters.

    Args:
        carry_m1: carry from the carry function; reads ``base_s`` and the
            ``stats`` entries ``F`` / ``G`` (used only for their shapes).
        input_0: dict with ``key``, ``phi``, ``theta`` and ``ys_bptt`` (the
            last entry is the observation ``y_0``).
        **kwargs: options; reads ``bptt_depth``, ``p``, ``q`` and
            ``num_samples``.

    Returns:
        ``(carry_0, (params_q_0, backwd_params))``. ``H_0`` is the log prior
        plus log emission density per particle. ``F_0`` and ``B_0`` are
        zeros. ``G_0`` is the gradient of that log density w.r.t. the
        unformatted model parameters.
    """
    y_0 = input_0['ys_bptt'][-1]
    key_0, unformatted_phi_0, unformatted_theta_0 = input_0['key'], input_0['phi'], input_0['theta']

    bptt_depth = kwargs['bptt_depth']
    p: HMM = kwargs['p']
    q: BackwardSmoother = kwargs['q']
    num_samples = kwargs['num_samples']

    def get_state(phi):
        if bptt_depth == 1:
            s_t = q.init_state(y_0, phi)
            return s_t, s_t
        # s are the latent states
        base_s, (_, s_0) = q.get_states(0,
                                        carry_m1['base_s'],
                                        input_0['ys_bptt'],
                                        phi)
        return base_s, s_0

    def _log_q_0(unformatted_phi, key):
        phi = q.format_params(unformatted_phi)

        base_s, s_0 = get_state(phi)
        params_q_t = q.filt_params_from_state(s_0, phi)
        # Score-function estimator: the sample itself carries no gradient.
        x_t = jax.lax.stop_gradient(q.filt_dist.sample(key, params_q_t))
        return q.filt_dist.logpdf(x_t, params_q_t), (x_t, base_s, s_0, params_q_t)

    (log_q_0, (x_0, base_s, s_0, params_q_0)), grad_log_q_0 = jax.vmap(
        jax.value_and_grad(_log_q_0, argnums=0, has_aux=True),
        in_axes=(None, 0))(unformatted_phi_0,
                           jax.random.split(key_0, num_samples))

    # States and filtering parameters do not depend on the particle; keep one copy.
    base_s, s_0, params_q_0 = tree_get_idx(0, (base_s, s_0, params_q_0))

    # Model parameters used inside H_0. As in the update step, gradients
    # w.r.t. theta are tracked explicitly (through G), so autodiff through
    # H_0 is blocked.
    theta: HMM.Params = p.format_params(jax.lax.stop_gradient(unformatted_theta_0),
                                        precompute=['prec'])

    def _log_l_0(x_0, unformatted_model_params):
        model_params = p.format_params(unformatted_model_params, precompute=['prec'])
        return p.prior_dist.logpdf(x_0, model_params.prior) \
            + p.emission_kernel.logpdf(y_0, x_0, model_params.emission)

    # allow_int=True matches the theta gradients/Jacobians used elsewhere
    # (carry creation and update step), so integer leaves in theta are accepted.
    grad_log_l_0 = jax.vmap(jax.grad(_log_l_0, argnums=1, allow_int=True),
                            in_axes=(0, None))(x_0, unformatted_theta_0)

    def _h(x_0):
        return p.prior_dist.logpdf(x_0, theta.prior) \
            + p.emission_kernel.logpdf(y_0, x_0, theta.emission)

    H_0 = jax.vmap(_h)(x_0)

    F_0 = jax.tree.map(jnp.zeros_like, carry_m1['stats']['F'])
    G_0 = grad_log_l_0
    B_0 = jax.tree.map(jnp.zeros_like, carry_m1['stats']['G'])

    carry = {'stats': {'F': F_0,
                       'H': H_0,
                       'G': G_0,
                       'B': B_0},
             's': s_0,
             'base_s': base_s,
             'x': x_0,
             'log_q': log_q_0,
             'grad_log_q': grad_log_q_0}

    return carry, (params_q_0, q.backwd_params_from_states((s_0, s_0), q.format_params(unformatted_phi_0)))

def update_elbo_score_gradients(carry_tm1, input_t, **kwargs):
    """Advance the online ELBO score-gradient recursion from step t-1 to step t.

    Each of the ``num_samples`` particles draws a fresh ``x_t`` from the
    filtering distribution of ``q``. It then updates its statistics by
    averaging over the previous particles ``x_{t-1}``. With
    ``resampling=False`` the average uses self-normalised weights
    ``normalizer(log q_{t-1|t} - log q_{t-1})``; with ``resampling=True`` it
    is a plain mean over backward indices drawn from those weights.

    Per-particle statistics (``carry['stats']``):
        H: accumulated ``log transition + log emission - log backward
           density`` terms, averaged over ancestors.
        F: accompanying gradient estimate w.r.t. the unformatted
           variational parameters ``phi``.
        G: accumulated gradient of ``log transition + log emission``
           w.r.t. the unformatted model parameters ``theta``, averaged over
           ancestors.
        B: unweighted particle mean of the previous ``G``.

    Args:
        carry_tm1: previous carry (keys ``x``, ``base_s``, ``s``, ``log_q``,
            ``grad_log_q``, ``stats``).
        input_t: dict with ``t``, ``T``, ``key``, ``phi``, ``theta`` and
            ``ys_bptt`` (the last entry is ``y_t``).
        **kwargs: options; reads ``p``, ``q``, ``num_samples``,
            ``resampling``, ``bptt_depth``, ``mcmc``, ``normalizer`` and the
            optional ``num_back_idx``. ``num_back_idx`` is the number of
            backward indices drawn per particle when resampling without MCMC.
            It must be a static Python int and defaults to 50.

    Returns:
        ``(carry_t, (params_q_t, params_q_tm1_t))``. The parameter pytrees
        are taken with ``tree_get_idx(0, ...)`` from the per-particle results.
        ``grad_log_q`` is only computed at ``t == T``; at other steps it is
        zeros.
    """
    p: HMM = kwargs['p']
    q: BackwardSmoother = kwargs['q']
    num_samples = kwargs['num_samples']
    resampling = kwargs['resampling']
    bptt_depth = kwargs['bptt_depth']
    mcmc = kwargs['mcmc']
    normalizer = kwargs['normalizer']
    num_back_idx = kwargs.get('num_back_idx', 50)

    t, T, key_t = input_t['t'], input_t['T'], input_t['key']
    unformatted_phi_t, unformatted_theta_t = input_t['phi'], input_t['theta']

    ys_for_bptt = input_t['ys_bptt']
    y_t = ys_for_bptt[-1]

    x_tm1 = carry_tm1['x']
    base_s_tm1 = carry_tm1['base_s']
    s_tm1 = carry_tm1['s']
    log_q_tm1 = carry_tm1['log_q']
    stats_tm1 = carry_tm1['stats']
    H_tm1, F_tm1, G_tm1 = stats_tm1['H'], stats_tm1['F'], stats_tm1['G']

    # Model parameters used inside h_t. Gradients w.r.t. theta are tracked
    # explicitly through G, so autodiff through h_t is blocked here.
    theta_true = p.format_params(jax.lax.stop_gradient(unformatted_theta_t),
                                 precompute=['prec'])

    def get_states(phi):
        if bptt_depth == 1:
            s_t = q.new_state(y_t, s_tm1, phi)
            return s_t, (s_tm1, s_t)
        return q.get_states(t, base_s_tm1, ys_for_bptt, phi)

    def _log_q_tm1_t(unformatted_phi, x_tm1, x_t):
        """Log backward-kernel density log q(x_{t-1} | x_t) and its parameters."""
        phi = q.format_params(unformatted_phi)
        _, (s_prev, s_cur) = get_states(phi)
        params_q_tm1_t = q.backwd_params_from_states((s_prev, s_cur), phi)
        log_q_tm1_t = q.backwd_kernel.logpdf(x_tm1, x_t, params_q_tm1_t)
        return log_q_tm1_t, params_q_tm1_t

    def _log_q_t(unformatted_phi, key):
        """Sample x_t from the filtering distribution and return its log density."""
        phi = q.format_params(unformatted_phi)
        base_s_t, (_, s_t) = get_states(phi)
        params_q_t = q.filt_params_from_state(s_t, phi)
        # Score-function estimator: the sample itself carries no gradient.
        x_t = jax.lax.stop_gradient(q.filt_dist.sample(key, params_q_t))
        return q.filt_dist.logpdf(x_t, params_q_t), (x_t, base_s_t, s_t, params_q_t)

    def _log_q_t_and_dummy_grad(unformatted_phi, key):
        out = _log_q_t(unformatted_phi, key)
        dummy_grad = jax.tree.map(lambda x: jnp.zeros_like(x[0]), carry_tm1['grad_log_q'])
        return out, dummy_grad

    def _log_q_t_and_grad(unformatted_phi, key):
        return jax.value_and_grad(_log_q_t, has_aux=True)(unformatted_phi, key)

    def _log_l_t(unformatted_theta, x_tm1, x_t):
        """log transition + log emission density at (x_{t-1}, x_t, y_t)."""
        theta = p.format_params(unformatted_theta, precompute=['prec'])
        return (p.transition_kernel.logpdf(x_t, x_tm1, theta.transition)
                + p.emission_kernel.logpdf(y_t, x_t, theta.emission))

    def update(key_t):
        if resampling:
            key_new_sample, key_resampling = jax.random.split(key_t, 2)
        else:
            key_new_sample = key_t

        # The filtering-density gradient is only needed at the final step.
        (log_q_t, (x_t, base_s_t, s_t, params_q_t)), grad_log_q_t = lax.cond(
            t == T,
            _log_q_t_and_grad,
            _log_q_t_and_dummy_grad,
            unformatted_phi_t,
            key_new_sample)

        def _h(x_tm1, log_q_tm1_t):
            return p.transition_kernel.logpdf(x_t, x_tm1, theta_true.transition) \
                + p.emission_kernel.logpdf(y_t, x_t, theta_true.emission) - log_q_tm1_t

        _vmapped_h = jax.vmap(_h, in_axes=(0, 0))

        # Defined for both branches so the returned carry always has a 'B'.
        B_t = jax.tree.map(lambda G_prev: jnp.mean(G_prev, axis=0), G_tm1)

        if not resampling:
            # SNIS estimator over all previous particles.
            (log_q_tm1_t, params_q_tm1_t), grad_log_q_tm1_t = jax.vmap(
                jax.value_and_grad(_log_q_tm1_t, has_aux=True),
                in_axes=(None, 0, None))(unformatted_phi_t, x_tm1, x_t)

            h_t = _vmapped_h(x_tm1, log_q_tm1_t)

            log_w_t = log_q_tm1_t - log_q_tm1
            w_t = normalizer(log_w_t)

            H_t = jax.vmap(lambda w, H, h: w * (H + h))(w_t, H_tm1, h_t)
            H_t = jnp.sum(H_t, axis=0)

            control_variate = H_t

            F_t = jax.tree.map(
                lambda F, grad_log_backwd: jax.vmap(
                    lambda w, F_i, H, h, g: w * (F_i + g * (H + h - control_variate)))(
                        w_t, F, H_tm1, h_t, grad_log_backwd),
                F_tm1,
                grad_log_q_tm1_t)
            F_t = jax.tree.map(lambda x: jnp.sum(x, axis=0), F_t)

            # Each leaf: [num_samples, ...]; weight and sum over particles.
            grad_log_l_all = jax.vmap(jax.grad(_log_l_t, allow_int=True),
                                      in_axes=(None, 0, None))(unformatted_theta_t, x_tm1, x_t)
            G_t = jax.tree.map(
                lambda G_prev, gl: jnp.sum(
                    w_t.reshape((w_t.shape[0],) + (1,) * (gl.ndim - 1)) * (G_prev + gl),
                    axis=0),
                G_tm1,
                grad_log_l_all)

        else:
            log_q_tm1_t, params_q_tm1_t = jax.vmap(
                _log_q_tm1_t,
                in_axes=(None, 0, None))(unformatted_phi_t, x_tm1, x_t)

            log_w_t = log_q_tm1_t - log_q_tm1

            if mcmc:
                backwd_sampler = blackjax.irmh(
                    logdensity_fn=lambda i: log_w_t[i],
                    proposal_distribution=lambda key: jax.random.choice(key, a=num_samples))

                def _backwd_sample_step(state, x):
                    step_nb, key = x

                    def _init(state, key):
                        return backwd_sampler.init(jax.random.choice(key, a=num_samples))

                    def _step(state, key):
                        return backwd_sampler.step(key, state)[0]

                    new_state = jax.lax.cond(step_nb > 0, _step, _init, state, key)
                    return new_state, new_state.position

                # First scan step (re)initialises the chain; its position is dropped.
                backwd_indices = jax.lax.scan(
                    _backwd_sample_step,
                    init=backwd_sampler.init(0),
                    xs=(jnp.arange(3),
                        jax.random.split(key_resampling, 3)))[1][1:]
            else:
                backwd_indices = jax.random.choice(key_resampling,
                                                   a=num_samples,
                                                   p=normalizer(log_w_t),
                                                   shape=(num_back_idx,))

            # Ancestor particles / statistics selected by the backward indices.
            sub_x_tm1 = x_tm1[backwd_indices]
            sub_H_tm1 = H_tm1[backwd_indices]
            sub_log_q_tm1_t = log_q_tm1_t[backwd_indices]

            sub_h_t = _vmapped_h(sub_x_tm1, sub_log_q_tm1_t)

            H_t = jnp.mean(sub_H_tm1 + sub_h_t, axis=0)

            control_variate = H_t

            sub_grad_log_q_tm1_t, _ = jax.vmap(
                jax.grad(_log_q_tm1_t, has_aux=True),
                in_axes=(None, 0, None))(unformatted_phi_t, sub_x_tm1, x_t)

            sub_grad_log_l_t = jax.vmap(
                jax.grad(_log_l_t, allow_int=True),
                in_axes=(None, 0, None))(unformatted_theta_t, sub_x_tm1, x_t)

            F_t = jax.tree.map(
                lambda F, sub_grad_log_backwd: jax.vmap(
                    lambda F_i, H, h, g: F_i + g * (H + h - control_variate))(
                        F[backwd_indices], sub_H_tm1, sub_h_t, sub_grad_log_backwd),
                F_tm1,
                sub_grad_log_q_tm1_t)
            F_t = jax.tree.map(lambda x: jnp.mean(x, axis=0), F_t)

            G_t = jax.tree.map(
                lambda G, sub_grad_log_l: jnp.mean(G[backwd_indices] + sub_grad_log_l, axis=0),
                G_tm1,
                sub_grad_log_l_t)

        return (F_t, H_t, G_t, B_t, x_t, log_q_t, grad_log_q_t,
                base_s_t, s_t, params_q_t, tree_get_idx(0, params_q_tm1_t))

    (F_t, H_t, G_t, B_t, x_t, log_q_t, grad_log_q_t,
     base_s_t, s_t, params_q_t, params_q_tm1_t) = jax.vmap(update)(
        jax.random.split(key_t, num_samples))

    carry_t = {'stats': {'F': F_t, 'H': H_t, 'G': G_t, 'B': B_t},
               'base_s': tree_get_idx(0, base_s_t),
               's': tree_get_idx(0, s_t),
               'x': x_t,
               'log_q': log_q_t,
               'grad_log_q': grad_log_q_t}

    return carry_t, (tree_get_idx(0, params_q_t), tree_get_idx(0, params_q_tm1_t))





def postprocess_elbo_score_gradients(carry,
                                     **kwargs):
    """Reduce the final carry to the ELBO estimate and its two gradient estimates.

    Args:
        carry: final carry; reads ``stats`` (``H``, ``F``, ``G``, ``B``),
            ``log_q`` and ``grad_log_q``, each with a leading particle axis.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ``(elbo, grad, grad_model)``:

        - ``elbo``: particle mean of ``H - log_q``.
        - ``grad``: particle mean of ``(H - baseline) * grad_log_q + F`` per
          variational-parameter leaf.
        - ``grad_model``: particle mean of ``G`` per model-parameter leaf.
    """
    stats = carry['stats']
    H_T = stats['H']
    log_q_T = carry['log_q']
    F_T = stats['F']
    G_T = stats['G']
    # Read but not used in the returned gradients.
    _B_T = stats['B']
    grad_log_q_T = carry['grad_log_q']

    elbo = jnp.mean(H_T - log_q_T, axis=0)

    # Score-function baseline, detached so it acts as a constant.
    baseline = jax.lax.stop_gradient(elbo)

    # NOTE(review): the score weight below is ``H_T - baseline`` while ``elbo``
    # averages ``H_T - log_q_T``; confirm whether a ``-log_q_T`` term is
    # intended in the per-particle weight — TODO confirm.
    def _per_particle(H, score, F):
        return (H - baseline) * score + F

    def _leaf_grad(score_leaf, F_leaf):
        return jnp.mean(jax.vmap(_per_particle)(H_T, score_leaf, F_leaf), axis=0)

    grad = jax.tree.map(_leaf_grad, grad_log_q_T, F_T)
    grad_model = jax.tree.map(lambda G: jnp.mean(G, axis=0), G_T)

    return elbo, grad, grad_model
 








def OnlineELBOScoreGradients(p, q, num_samples, **options):
    """Build an ``OnlineVariationalAdditiveSmoothing`` wired for online ELBO score gradients.

    Defined as a function rather than a lambda bound to a name (PEP 8 E731),
    so it has a proper ``__name__`` in tracebacks.

    Args:
        p: the model.
        q: the variational backward smoother.
        num_samples: number of particles.
        **options: extra options forwarded to every wired function. The
            wired init/update functions read ``bptt_depth``, ``resampling``,
            ``mcmc`` and ``normalizer`` from them.

    Returns:
        The configured ``OnlineVariationalAdditiveSmoothing`` instance.
    """
    return OnlineVariationalAdditiveSmoothing(
        p,
        q,
        online_elbo_functional,
        init_carry_fn=init_carry_elbo_score_gradients,
        init_fn=init_elbo_score_gradients,
        update_fn=update_elbo_score_gradients,
        preprocess_fn=preprocess_for_bptt,
        postprocess_fn=postprocess_elbo_score_gradients,
        num_samples=num_samples,
        **options)



