from functools import partial
from typing import Optional

import chex
import jax
import numpy as np
import jax.numpy as jnp
from flax import struct

from src.utils import *


class RBFKernel:
    """Squared-exponential (RBF) kernel with an additive noise term.

    Hyper-parameters are kept in a dict with keys ``"lengthscales"``,
    ``"outputscale"`` and ``"noise"`` and are passed explicitly to
    ``__call__`` rather than stored on the instance.

    Args:
        use_ard: if falsy, a single shared lengthscale is used
            (``ard_num_dims`` is forced to 1).
        ard_num_dims: number of lengthscales when ``use_ard`` is truthy.
        lengthscale_prior: object exposing ``.a`` and ``.b``, used as the
            lower / upper lengthscale bounds (presumably a uniform prior --
            TODO confirm).
        *args, **kwargs: accepted and ignored so subclasses can forward
            additional priors.
    """

    def __init__(self, use_ard, ard_num_dims, lengthscale_prior,
                 *args, **kwargs):
        if not use_ard:
            ard_num_dims = 1
        self.ard_num_dims = ard_num_dims

        # The plain RBF kernel needs no extra state; subclasses replace this.
        self.state = None

        # Lower / upper bounds per hyper-parameter. Only the lengthscale
        # bounds come from a prior; the other entries are placeholders.
        self.l_b = {
            "lengthscales": lengthscale_prior.a,
            "outputscale": 0.,  # placeholder
            "noise": 0.,  # placeholder
        }
        self.u_b = {
            "lengthscales": lengthscale_prior.b,
            "outputscale": 0.,  # placeholder
            "noise": 0.,  # placeholder
        }

        # Midpoint of [l_b, u_b] for every hyper-parameter.
        self.prior_mean = jax.tree_util.tree_map(
            lambda lo, hi: (lo + hi) / 2, self.l_b, self.u_b)

        # Lengthscales start at the prior midpoint so that mpd does not move
        # too much at the start. outputscale / noise are placeholder values
        # that are modified at the initial iteration.
        self.init_params = {
            "lengthscales": self.prior_mean["lengthscales"] * jnp.ones(ard_num_dims),
            "outputscale": jnp.array(0.02),
            "noise": jnp.array(0.01),
        }

    def _prepare_inputs(self, kernel_state, x1, x2):
        """Hook for subclasses to transform raw inputs; identity here."""
        return x1, x2

    def _forward_inputs_hps(self, kernel_params, kernel_state, x1, x2):
        """Return ``outputscale**2 * rbf(x1', x2') + noise**2 * is_equal(x1', x2')``.

        ``x1'``/``x2'`` are the inputs divided by
        ``lengthscales * sqrt(ard_num_dims)``. ``rbf`` and ``is_equal`` come
        from ``src.utils`` (presumably an RBF Gram matrix and an elementwise
        equality indicator matrix -- TODO confirm).
        """
        # Divide by the lengthscale and normalise by the number of ARD dims.
        scale = kernel_params["lengthscales"] * jnp.sqrt(self.ard_num_dims)
        x1 = jnp.divide(x1, scale)
        x2 = jnp.divide(x2, scale)

        rslt = (kernel_params["outputscale"] ** 2) * rbf(x1, x2)
        rslt += (kernel_params["noise"] ** 2) * is_equal(x1, x2)
        return rslt

    def __call__(self, kernel_params, kernel_state, x1, x2):
        """Evaluate the kernel between ``x1`` and ``x2``.

        Args:
            kernel_params: dict with ``"lengthscales"``, ``"outputscale"``
                and ``"noise"``.
            kernel_state: extra non-hyper-parameter state (unused here).
            x1, x2: input batches.
        """
        x1, x2 = self._prepare_inputs(kernel_state, x1, x2)
        return self._forward_inputs_hps(kernel_params, kernel_state, x1, x2)



@struct.dataclass
class RBFActionState:
    """Non-hyper-parameter state consumed by ``RBFAction`` / ``RBFState``.

    Kept apart from the hyper-parameter dict so that gradients are taken
    only with respect to the hyper-parameters.

    Attributes (as used by the kernels in this module):
        obs_params: passed unchanged as the second argument of ``mlp``.
        local_states: batch of states the policies are evaluated on
            (mapped over axis 0).
        local_masks: per-state weights for averaging the per-state
            distances (squeezed before use).
    """

    obs_params: chex.Array
    local_states: Optional[chex.Array] = None
    local_masks: Optional[chex.Array] = None
    



class RBFAction(RBFKernel):
    """RBF kernel over policy parameters, compared through their actions.

    Each flat parameter vector is reshaped with ``param_reshaper.reshape``,
    run through ``mlp`` on every state in ``kernel_state.local_states``, and
    the resulting actions are compared per state. Per-state distances are
    averaged over states using ``kernel_state.local_masks`` as weights.

    Args:
        use_ard, ard_num_dims: see ``RBFKernel``.
        outputscale_prior, noise_prior: stored as attributes; not otherwise
            used in this class.
        lengthscale_prior: forwarded to ``RBFKernel`` (lengthscale bounds).
        mlp: callable ``mlp(actor_params, obs_params, state)`` (signature as
            implied by the vmaps in ``__init__``).
        param_reshaper: provides ``reshape`` (flat -> structured params) and
            ``vmap_dict`` (``in_axes`` spec for a batch of structured params).
        obs_params: initial ``obs_params`` stored in ``self.state``.
    """

    def __init__(self, use_ard, ard_num_dims,
                 outputscale_prior, lengthscale_prior, noise_prior,
                 mlp, param_reshaper, obs_params):
        # RBFKernel's third positional parameter is ``lengthscale_prior``;
        # the remaining priors go through its ignored **kwargs.
        super().__init__(use_ard, ard_num_dims, lengthscale_prior,
                         outputscale_prior=outputscale_prior,
                         noise_prior=noise_prior)

        self.state = RBFActionState(obs_params=obs_params)
        self.mlp = mlp

        # Expose the constructor arguments as attributes. ``ard_num_dims`` is
        # deliberately not reassigned: RBFKernel already stored it (forced to
        # 1 when ``use_ard`` is falsy), consistent with ``init_params``.
        self.use_ard = use_ard
        self.outputscale_prior = outputscale_prior
        self.lengthscale_prior = lengthscale_prior
        self.noise_prior = noise_prior
        self.param_reshaper = param_reshaper
        self.obs_params = obs_params

        # apply_many_states(actor_params, obs_params, states)
        #   -> (n_states, act_dim), assuming mlp returns an (act_dim,) action.
        self.apply_many_states = jax.vmap(
            self.mlp, in_axes=(None, None, 0), out_axes=0)

        # run_params(params, obs_params, states) -> (n_params, n_states, act_dim)
        self.run_params = jax.vmap(
            self.apply_many_states,
            in_axes=(param_reshaper.vmap_dict, None, None))

        # Pairwise distances for each state (axis 1), stacked on the last
        # axis -> (n_params, n_params, n_states).
        self.parallel_dist = jax.vmap(
            pdist_squareform, in_axes=(1, 1), out_axes=-1)

        self.reshape = param_reshaper.reshape

    def _prepare_inputs(self, kernel_state, x1, x2):
        """Map flat parameter batches ``x1``/``x2`` to their actions on the local states."""
        a1 = self.run_params(self.reshape(x1), kernel_state.obs_params,
                             kernel_state.local_states)
        a2 = self.run_params(self.reshape(x2), kernel_state.obs_params,
                             kernel_state.local_states)
        return a1, a2

    def _forward_inputs_hps(self, kernel_params, kernel_state, a1, a2):
        """Kernel matrix between two batches of actions.

        ``a1``/``a2`` come from ``_prepare_inputs`` (layout fixed by
        ``self.run_params``). They are divided by ``lengthscales`` (broadcast
        on the last axis), compared per state with ``self.parallel_dist``,
        and the per-state distances are averaged -- weighted by
        ``kernel_state.local_masks``, or uniformly when it is ``None``.
        The ``exp(-0.5 * d)`` form presumes ``pdist_squareform`` returns
        squared distances -- TODO confirm.
        """
        a1 = jnp.divide(a1, kernel_params["lengthscales"])
        a2 = jnp.divide(a2, kernel_params["lengthscales"])

        actionwise_dist = self.parallel_dist(a1, a2)  # (n_params, n_params, n_states)

        local_masks = kernel_state.local_masks
        if local_masks is None:
            # No mask supplied: every state gets the same weight.
            pdist = jnp.mean(actionwise_dist, axis=-1)
        else:
            local_masks = local_masks.squeeze()
            pdist = jnp.sum(actionwise_dist * local_masks, axis=-1) / local_masks.sum()

        K = (kernel_params["outputscale"] ** 2) * jnp.exp(-0.5 * pdist)
        # HOTFIX: x1/x2 are not available here, so equality for the noise
        # term is tested on the (scaled) actions averaged over axis 1.
        K += (kernel_params["noise"] ** 2) * is_equal(a1.mean(axis=1), a2.mean(axis=1))
        return K

    @partial(jax.jit, static_argnums=(0,))
    def __call__(self, kernel_params, kernel_state, x1, x2):
        """Evaluate the kernel between flat parameter batches ``x1`` and ``x2``.

        Params are kept separate from the rest of the state so that the GP
        class can compute gradients with respect to them.
        """
        a1, a2 = self._prepare_inputs(kernel_state, x1, x2)
        return self._forward_inputs_hps(kernel_params, kernel_state, a1, a2)
  
  
class RBFState(RBFAction):
    """``RBFAction`` variant whose per-policy output is ``(act_dim, n_states)``.

    Only the three vmapped helpers differ from ``RBFAction``. Input
    preparation, kernel evaluation and the jitted ``__call__`` are inherited;
    they were previously duplicated here verbatim. Because of the transposed
    layout:

    * ``lengthscales`` (broadcast on the last axis) scale the state axis
      rather than the action axis when ``ard_num_dims > 1``;
    * the noise-term equality check averages over action dims (axis 1).
    """

    def __init__(self, use_ard, ard_num_dims,
                 outputscale_prior, lengthscale_prior, noise_prior,
                 mlp, param_reshaper, obs_params):
        super().__init__(use_ard, ard_num_dims,
                         outputscale_prior, lengthscale_prior, noise_prior,
                         mlp, param_reshaper, obs_params)

        # apply_many_states(actor_params, obs_params, states) -> (act_dim, n_states)
        self.apply_many_states = jax.vmap(
            self.mlp, in_axes=(None, None, 0), out_axes=-1)

        # Pairwise distances for each state (now axis 2), stacked on the last
        # axis -> (n_params, n_params, n_states).
        self.parallel_dist = jax.vmap(
            pdist_squareform, in_axes=(2, 2), out_axes=-1)

        # run_params(params, obs_params, states) -> (n_params, act_dim, n_states)
        self.run_params = jax.vmap(
            self.apply_many_states,
            in_axes=(param_reshaper.vmap_dict, None, None))
  
  
    

      


  
    

      

