"""
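Implementation of a multilayer perceptron (MLP), i.e., a fully connected model, together with a periodic-input variant, parameter initializers, and a divergence-free vector-field wrapper.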

"""
import jax.numpy as jnp
from jax import random
from jax import jacfwd,jacrev
# ------Copied from Jax documentation-----
def random_layer_params(m: int, n: int, key, scale: float = 1e-1):
    """Draw scaled Gaussian parameters for one dense layer.

    Args:
        m: Input width of the layer.
        n: Output width of the layer.
        key: JAX PRNG key; it is split into one key for the weights and one
            for the biases.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        Tuple ``(w, b)`` with ``w`` of shape (n, m) and ``b`` of shape (n,).
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m)) * scale
    biases = random.normal(bias_key, (n,)) * scale
    return weights, biases

# # Initialize all layers for a fully-connected neural network with sizes "sizes"
# def init_params(sizes, key):
#   keys = random.split(key, len(sizes))
#   return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
# # ----------Copy ends----------------------

def mlp(activation):
    """Return the forward function of a fully connected network.

    Args:
        activation: Elementwise nonlinearity applied after every layer except
            the last one.

    Returns:
        A function ``model(params, inpt)`` where ``params`` is a sequence of
        ``(w, b)`` pairs. Each hidden layer computes
        ``activation(jnp.dot(w, h) + b)``; the final pair is applied affinely
        with no activation.
    """
    def model(params, inpt):
        h = inpt
        for layer_w, layer_b in params[:-1]:
            h = activation(jnp.dot(layer_w, h) + layer_b)

        # Output layer: affine only, no nonlinearity.
        out_w, out_b = params[-1]
        return jnp.dot(out_w, h) + out_b

    return model


def apply_periodic_embedding(x, periodic_dims, periods):
    """
    Replace selected features with sine/cosine periodic embeddings.

    Features are indexed along the LAST axis of ``x``, so both a single sample
    of shape (features,) and a batch of shape (..., features) are supported.
    The non-periodic features are kept in their original order. After them,
    each periodic feature ``f = x[..., d]`` with period ``p`` contributes the
    pair ``sin(2*pi*f/p), cos(2*pi*f/p)``, in the order given by
    ``periodic_dims``.

    Args:
    x: Input array of shape (features,) or (..., features).
    periodic_dims: Feature indices to embed. Negative indices count from the
        end, as in normal Python indexing.
    periods: Periods corresponding to the entries of ``periodic_dims``.

    Returns:
    Augmented array of shape (..., features + len(periodic_dims)) when the
    entries of ``periodic_dims`` are distinct. Each embedded feature is
    replaced by two values.

    Raises:
    ValueError: If ``periodic_dims`` and ``periods`` differ in length, or if a
        periodic dimension is out of range for the feature axis.
    """
    if len(periodic_dims) != len(periods):
        raise ValueError("Each periodic dimension must have a corresponding period value.")

    n_features = x.shape[-1]

    # Normalize to non-negative indices so the membership test below also
    # excludes features that were given by a negative index. Validate
    # explicitly because JAX clamps out-of-bounds indices instead of raising,
    # which would silently embed the wrong feature.
    dims = []
    for d in periodic_dims:
        d = int(d)
        if not -n_features <= d < n_features:
            raise ValueError(
                f"Periodic dimension {d} is out of range for {n_features} features."
            )
        dims.append(d % n_features)
    periodic_set = set(dims)

    columns = [x[..., i] for i in range(n_features) if i not in periodic_set]
    for d, period in zip(dims, periods):
        angle = x[..., d] * (2 * jnp.pi / period)
        columns.extend([jnp.sin(angle), jnp.cos(angle)])

    if not columns:
        # jnp.stack rejects an empty list; return an empty feature axis instead.
        return jnp.zeros(x.shape[:-1] + (0,))
    return jnp.stack(columns, axis=-1)



def mlp_with_periodic_embedding(activation, periodic_dims, periods):
    """Return an MLP forward function that first embeds periodic input features.

    The input is passed through ``apply_periodic_embedding`` with the given
    ``periodic_dims``/``periods``, and the result is fed to the plain ``mlp``
    forward pass. Reusing ``mlp`` keeps the two forward passes from drifting
    apart.

    Args:
        activation: Elementwise nonlinearity for every layer except the last.
        periodic_dims: Feature indices to replace with sin/cos embeddings.
        periods: One period per entry of ``periodic_dims``.

    Returns:
        A function ``model(params, inpt)``. The first layer's weight matrix
        must match the width of the embedded (augmented) input, not the raw
        input.
    """
    base_model = mlp(activation)

    def model(params, inpt):
        augmented_input = apply_periodic_embedding(inpt, periodic_dims, periods)
        return base_model(params, augmented_input)

    return model


import jax.nn.initializers as initializers
def glorot_normal_layer_params(m: int, n: int, key):
    """Create one dense layer's parameters with Glorot-normal weights.

    Args:
        m: Input width of the layer.
        n: Output width of the layer.
        key: JAX PRNG key consumed by the weight initializer.

    Returns:
        Tuple ``(w, b)``: ``w`` of shape (n, m) drawn from
        ``jax.nn.initializers.glorot_normal()``, and ``b`` a zero vector of
        shape (n,).
    """
    weights = initializers.glorot_normal()(key, (n, m))
    biases = jnp.zeros((n,))
    return weights, biases

def init_params(sizes, key):
    """Initialize every layer of a fully connected network.

    Args:
        sizes: Layer widths, input first, e.g. ``[d_in, h1, ..., d_out]``.
        key: JAX PRNG key. It is split into ``len(sizes)`` subkeys and the
            first ``len(sizes) - 1`` of them are used, one per layer.

    Returns:
        List of ``(w, b)`` tuples from ``glorot_normal_layer_params``, one for
        each consecutive pair of widths in ``sizes``.
    """
    layer_keys = random.split(key, len(sizes))
    params = []
    # zip stops at the shortest input (sizes[1:]), giving len(sizes) - 1 layers.
    for fan_in, fan_out, layer_key in zip(sizes, sizes[1:], layer_keys):
        params.append(glorot_normal_layer_params(fan_in, fan_out, layer_key))
    return params
def divn(F):
    """Return a function computing the divergence of ``F``.

    The forward-mode Jacobian of ``F`` is traced over its last two axes. For a
    vector field this gives ``sum_i dF_i/dx_i``. For a matrix-valued ``F`` it
    gives the row-wise divergence ``sum_j dF_ij/dx_j``.
    """
    jacobian = jacfwd(F)

    def divergence(x):
        return jnp.trace(jacobian(x), axis1=-2, axis2=-1)

    return divergence
class DivFreeImplicit:
    """Divergence-free vector field derived from a wrapped network.

    For fixed ``t``, let ``u(x) = network(params, [t, x])`` and let ``J`` be its
    Jacobian with respect to ``x``. The antisymmetric matrix field
    ``A = J - J^T`` is formed, and its row-wise divergence
    ``sum_j dA_ij/dx_j`` is returned. The divergence of that result is zero
    because ``A`` is antisymmetric. ``J - J.T`` presumes the network output has
    the same length as ``x``.
    """

    def __init__(self, network):
        # Called as network(params, tx), with tx = [t, x_1, ..., x_d].
        self.network = network

    def __call__(self, params, tx):
        # The first entry is time, the rest are spatial coordinates.
        time, space = tx[0:1], tx[1:]

        def u_of_x(xi):
            return self.network(params, jnp.concatenate([time, xi]))

        def antisymmetric_potential(y):
            jac = jacfwd(u_of_x)(y)
            return jac - jac.T

        # Row-wise divergence: trace of the Jacobian of A over its last two axes.
        return jnp.trace(jacfwd(antisymmetric_potential)(space), axis1=-2, axis2=-1)
