import jax
from jax import numpy as jnp
import flax 
import optax
from functools import partial

@partial(jax.jit, donate_argnums=0)
def tree_transpose(list_of_trees):
    '''
    Transpose a sequence of pytrees into a single pytree of stacked arrays.

    Every tree in ``list_of_trees`` must have the same structure. Each leaf
    of the result is ``jnp.array`` applied to the tuple of corresponding
    leaves, so the position in the input sequence becomes the leading axis.

    The input is donated to the jitted computation (``donate_argnums=0``),
    so callers should not reuse the input buffers afterwards.
    '''
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated/removed in recent JAX releases.
    return jax.tree_util.tree_map(lambda *xs: jnp.array(xs), *list_of_trees)

def recursive_dict_update(key, from_dict, target_dict):
    '''
    Copy the entry addressed by the path ``key`` from ``from_dict`` into
    ``target_dict`` (in place).

    ``key`` is an indexable sequence of nested dictionary keys. Missing
    intermediate levels are created in ``target_dict`` as empty dicts; the
    value at the end of the path is taken from ``from_dict`` and stored
    under the same path in ``target_dict``. Returns None.
    '''
    source, dest, path = from_dict, target_dict, key
    # Walk down both dictionaries until only the last path component is left.
    while len(path) != 1:
        head = path[0]
        if head not in dest:
            dest[head] = {}
        source = source[head]
        dest = dest[head]
        path = path[1:]
    leaf = path[0]
    dest.update({leaf: source[leaf]})

def get_params_subset(params, subset_keys):
    '''
    Build a nested dict holding only the selected entries of ``params``.

    ``subset_keys`` is an iterable of key paths (each path is a sequence of
    nested dictionary keys, as consumed by ``recursive_dict_update``). When
    ``subset_keys`` is None the original ``params`` object is returned
    unchanged.
    '''
    if subset_keys is None:
        return params

    selected = {}
    for path in subset_keys:
        recursive_dict_update(path, params, selected)
    return selected

def replace_params_subset(params, subset_keys, params_subset):
    '''
    Return ``params`` with the entries addressed by ``subset_keys`` replaced
    by the corresponding entries of ``params_subset``.

    ``subset_keys`` is an iterable of key paths (sequences of nested dict
    keys). When it is None the subset is the whole parameter tree (this
    mirrors ``get_params_subset``), so ``params_subset`` itself is returned.
    Returning ``params`` in that case would make any function that is
    differentiated with respect to ``params_subset`` independent of its
    argument, producing an all-zero Jacobian. If ``params_subset`` is None
    as well, ``params`` is returned.

    Otherwise the result is a new mutable dict obtained from flax's
    ``unfreeze``; the replaced entries are the objects from
    ``params_subset``.
    '''
    if subset_keys is None:
        return params if params_subset is None else params_subset

    # Work on an unfrozen (mutable) copy so entries can be overwritten.
    _params = flax.core.frozen_dict.unfreeze(params)
    for key in subset_keys:
        recursive_dict_update(key, params_subset, _params)
    return _params

def mean_pred(params_subset, state, batch, params_keys, aggregate=False):
    '''
    Mean model output over a batch, as a function of a parameter subset.

    ``params_subset`` is merged into ``state.params`` at the paths given by
    ``params_keys``; the model is then applied to ``batch['image']`` with
    ``train=False`` and ``mutable=False``.

    With ``aggregate`` false the mean is taken over axis 0 of the model
    output (presumably the batch axis); with ``aggregate`` true the mean is
    taken over all elements of the output.
    '''
    variables = {
        'params': replace_params_subset(state.params, params_keys, params_subset),
        'batch_stats': state.batch_stats,
    }
    outputs = state.apply_fn(variables, batch['image'], train=False, mutable=False)
    if aggregate:
        return jnp.mean(outputs)
    return jnp.mean(outputs, axis=0)

def pred(params_subset, state, x, params_keys, aggregate=False):
    '''
    Model output for inputs ``x``, as a function of a parameter subset.

    ``params_subset`` is merged into ``state.params`` at the paths given by
    ``params_keys``; the model is then applied to ``x`` with ``train=False``
    and ``mutable=False``.

    With ``aggregate`` false the raw model output is returned; with
    ``aggregate`` true the output is summed over axis 1 (presumably the
    output-head axis).
    '''
    variables = {
        'params': replace_params_subset(state.params, params_keys, params_subset),
        'batch_stats': state.batch_stats,
    }
    outputs = state.apply_fn(variables, x, train=False, mutable=False)
    return jnp.sum(outputs, axis=1) if aggregate else outputs
      
def batch_loss(params_subset, state, batch, params_keys):
    '''
    Softmax cross-entropy for a batch, as a function of a parameter subset.

    ``params_subset`` is merged into ``state.params`` at the paths given by
    ``params_keys``. The model is applied to ``batch['image']`` with
    ``train=False`` and ``mutable=False`` and the resulting logits are
    compared against ``batch['label']`` with
    ``optax.softmax_cross_entropy``; its result is returned without any
    further reduction.
    '''
    merged = replace_params_subset(state.params, params_keys, params_subset)
    variables = {'params': merged, 'batch_stats': state.batch_stats}
    return optax.softmax_cross_entropy(
        logits=state.apply_fn(variables, batch['image'], train=False, mutable=False),
        labels=batch['label'],
    )
    
@partial(jax.jit, static_argnames=['params_keys', 'aggregate'])
def mean_batch_jacobian(state, batch, params_keys, aggregate):
    '''
    Jacobian of the batch-mean prediction with respect to a parameter subset.

    The differentiated function is ``mean_pred`` with ``state``, ``batch``,
    ``params_keys`` and ``aggregate`` held fixed; it is evaluated at the
    subset of ``state.params`` selected by ``params_keys``. ``params_keys``
    and ``aggregate`` are static jit arguments, so they must be hashable.
    '''
    def _mean_output(subset):
        return mean_pred(subset, state=state, batch=batch,
                         params_keys=params_keys, aggregate=aggregate)

    return jax.jacobian(_mean_output)(get_params_subset(state.params, params_keys))

@partial(jax.jit, static_argnames=['params_keys', 'aggregate'])
def jacobian(state, x, params_keys, aggregate):
    '''
    Jacobian of the model prediction on ``x`` with respect to a parameter
    subset.

    The differentiated function is ``pred`` with ``state``, ``x``,
    ``params_keys`` and ``aggregate`` held fixed; it is evaluated at the
    subset of ``state.params`` selected by ``params_keys``. ``params_keys``
    and ``aggregate`` are static jit arguments, so they must be hashable.
    The original note describes the result as indexed (out_head, ...).
    '''
    def _output(subset):
        return pred(subset, state=state, x=x,
                    params_keys=params_keys, aggregate=aggregate)

    return jax.jacobian(_output)(get_params_subset(state.params, params_keys))

@partial(jax.jit, donate_argnums=0)
def replace(x, y, c):
    '''
    Return ``x`` with the slice at index ``c`` along the first axis set to
    ``y``.

    ``x`` is donated to the jitted computation, so the caller should not
    reuse the array passed as ``x`` after the call.
    '''
    index = (c, Ellipsis)
    return x.at[index].set(y)

@jax.jit
def features(state, x):
    '''
    Evaluate the model's ``features`` method on ``x`` for the given state.

    The module behind ``state.apply_fn`` must expose a ``features`` method
    that accepts the input and a ``train`` keyword; it is called with
    ``train=False`` using ``state.params`` and ``state.batch_stats``.
    '''
    def _call_features(module, inputs):
        return module.features(inputs, train=False)

    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    return state.apply_fn(variables, x, method=_call_features)

@partial(jax.jit, static_argnames=['params_keys'])
def loss_jacobian(state, batch, params_keys):
    '''
    Jacobian of the per-sample batch loss with respect to a parameter subset.

    The differentiated function is ``batch_loss`` with ``state``, ``batch``
    and ``params_keys`` held fixed; it is evaluated at the subset of
    ``state.params`` selected by ``params_keys``. ``params_keys`` is a
    static jit argument, so it must be hashable.
    '''
    def _loss(subset):
        return batch_loss(subset, state=state, batch=batch, params_keys=params_keys)

    return jax.jacobian(_loss)(get_params_subset(state.params, params_keys))


@partial(jax.jit, static_argnames=['aggregate'])
def jac_contract(x, y, aggregate):
    '''
    Contract two equally-shaped Jacobian leaves by summing their
    elementwise product.

    With ``aggregate`` false every axis except axis 0 is summed, leaving one
    value per entry of axis 0 (the out_head axis, per the file's
    convention). With ``aggregate`` true all axes are summed, giving a
    scalar.

    Raises ValueError when the shapes of ``x`` and ``y`` differ.
    '''
    if x.shape != y.shape:
        raise ValueError
    first_summed_axis = 0 if aggregate else 1
    summed_axes = tuple(range(first_summed_axis, x.ndim))
    return jnp.sum(x * y, axis=summed_axes)

def jac_dot(g1, g2, aggregate):
    '''
    Dot product of two Jacobian pytrees with identical structure.

    Each pair of corresponding leaves is contracted with ``jac_contract``
    and the per-leaf results are added together. With ``aggregate`` false
    the result keeps axis 0 of the leaves (shape (out_head,) under the
    file's convention); with ``aggregate`` true it is a scalar.
    '''
    _jac_contract = partial(jac_contract, aggregate=aggregate)
    # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is
    # deprecated/removed in recent JAX releases.
    per_leaf = jax.tree_util.tree_map(_jac_contract, g1, g2)
    return jax.tree_util.tree_reduce(jnp.add, per_leaf)

@partial(jax.jit, static_argnames=['aggregate'])
def jac_broadcast(x, y, aggregate):
    '''
    Insert singleton axes into ``x`` so it can be broadcast against ``y``.

    With ``aggregate`` false, axes 1..y.ndim-1 are inserted (``x`` is
    expected to carry the leading out_head axis already). With ``aggregate``
    true, axes 0..y.ndim-1 are inserted (``x`` is expected to be a scalar).
    Only the number of dimensions of ``y`` is used.
    '''
    first_new_axis = 0 if aggregate else 1
    new_axes = tuple(range(first_new_axis, y.ndim))
    return jnp.expand_dims(x, new_axes)

@partial(jax.jit, static_argnames=['aggregate'])
def block_ntk(jac_by_class, aggregate):
    '''
    Pairwise ``jac_dot`` products between all slices of ``jac_by_class``
    along the leading axis of its leaves.

    The outer vmap runs over the first argument and the inner vmap over the
    second, so entry [i, j] is the product of slice i with slice j. Per the
    original note the output is (num_classes, num_classes, out_heads).
    '''
    def _pair_dot(lhs, rhs):
        return jac_dot(lhs, rhs, aggregate)

    over_second = jax.vmap(_pair_dot, in_axes=(None, 0))
    over_both = jax.vmap(over_second, in_axes=(0, None))
    return over_both(jac_by_class, jac_by_class)

@partial(jax.jit, static_argnames=['aggregate'])
def block_svd(jac_by_class, aggregate):
    '''
    Eigendecomposition of the transposed block NTK.

    The kernel from ``block_ntk`` is transposed with ``.T`` (all axes
    reversed) and passed to ``jnp.linalg.eigh``. Both returned arrays are
    reversed along their last axis, so the eigenvalues (and the matching
    eigenvector columns) come out in the opposite order to the one
    ``eigh`` produces.

    Returns a tuple ``(eigenvalues, eigenvectors)``.
    '''
    # per the original note, the transposed kernel is
    # (out_heads, num_classes, num_classes)
    ntk = block_ntk(jac_by_class, aggregate).T
    # Decompose once and reuse both outputs instead of calling eigh twice.
    eigenvalues, eigenvectors = jnp.linalg.eigh(ntk)
    return eigenvalues[..., ::-1], eigenvectors[..., ::-1]
    
@partial(jax.jit, static_argnames=['aggregate'], donate_argnums=0)
def _scale_by_inverse_norm(x, u_norm, aggregate):
    # Divide one leaf by its vector's norm; x is donated so the result may
    # reuse its buffer.
    return x / jac_broadcast(u_norm, x, aggregate)


def normalize(vecs, num_vecs, aggregate):
    '''
    Normalize the first ``num_vecs`` vectors stored in the pytree ``vecs``.

    Vectors are indexed along axis 0 of every leaf. For each index ``i`` the
    i-th slice of all leaves is divided by the square root of its
    ``jac_dot`` product with itself and written back with ``replace``.

    ``replace`` is jitted with donation of its first argument, so the
    arrays passed in ``vecs`` should not be reused after this call; use the
    returned pytree instead.
    '''
    for i in range(num_vecs):  # number of eigenvectors
        # i-th eigenvector for all heads (out_head,...)
        u = jax.tree_util.tree_map(lambda x: x[i], vecs)
        u_norm = jnp.sqrt(jac_dot(u, u, aggregate))
        # A single module-level jitted helper is reused here; building a new
        # jax.jit(lambda ...) per iteration would retrace on every pass.
        e = jax.tree_util.tree_map(
            lambda x: _scale_by_inverse_norm(x, u_norm, aggregate=aggregate), u)
        vecs = jax.tree_util.tree_map(partial(replace, c=i), vecs, e)

    return vecs