"""
Contains the implementation of an MLP, i.e., a fully connected model.

"""
import jax.numpy as jnp
from jax import random

# ------Copied from Jax documentation-----
def random_layer_params(m: int, n: int, key, scale: float = 1e-1):
    """Sample scaled standard-normal parameters for one dense layer.

    Args:
        m: Number of columns of the weight matrix.
        n: Number of rows of the weight matrix and length of the bias vector.
        key: JAX PRNG key; it is split once into a weight key and a bias key.
        scale: Factor applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_params(sizes, key):
    """Create the parameters of every layer of a fully connected network.

    Args:
        sizes: Sequence of layer widths; consecutive entries give the
            (input, output) sizes of one layer.
        key: JAX PRNG key, split into ``len(sizes)`` sub-keys of which one
            is consumed per layer.

    Returns:
        A list with one ``(weights, biases)`` tuple per layer, in order.
    """
    layer_keys = random.split(key, len(sizes))
    params = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        params.append(random_layer_params(fan_in, fan_out, layer_key))
    return params
# ----------Copy ends----------------------

def mlp(activation):
    """Build the forward function of a fully connected network.

    Args:
        activation: Callable applied to the output of every layer except
            the last one.

    Returns:
        A function ``model(params, inpt)`` where ``params`` is a sequence of
        ``(weights, biases)`` pairs. Each layer computes
        ``jnp.dot(weights, hidden) + biases``; the last layer's result is
        returned without the activation.
    """
    def affine(weights, biases, vector):
        # One dense layer without its non-linearity.
        return jnp.dot(weights, vector) + biases

    def model(params, inpt):
        state = inpt
        for weights, biases in params[:-1]:
            state = activation(affine(weights, biases, state))

        # The output layer is purely affine.
        last_weights, last_biases = params[-1]
        return affine(last_weights, last_biases, state)

    return model


def apply_periodic_embedding(x, periodic_dims, periods):
    """
    Replace selected input features by sine and cosine periodic embeddings.

    Every feature listed in ``periodic_dims`` is removed from the input and
    replaced by the pair ``sin(2*pi*x_i / T_i)``, ``cos(2*pi*x_i / T_i)``,
    where ``T_i`` is the matching entry of ``periods``. All other features
    are passed through unchanged.

    Args:
    x: A single input sample whose features are indexed along the first
        axis, i.e. a 1-D vector of shape (features,). Extra axes are not
        treated as a batch: the result is always flattened to 1-D.
    periodic_dims: Sequence of concrete integer feature indices to which the
        embedding is applied. Negative indices count from the end.
    periods: Sequence of periods, one per entry of ``periodic_dims``.

    Returns:
    1-D array holding the non-periodic features in their original order,
    followed by one (sin, cos) pair per entry of ``periodic_dims``, in the
    order given there.

    Raises:
    ValueError: If ``periodic_dims`` and ``periods`` differ in length.
    """

    if len(periodic_dims) != len(periods):
        raise ValueError("Each periodic dimension must have a corresponding period value.")

    num_features = x.shape[0]

    # Normalise negative indices so that a dimension addressed as e.g. -1 is
    # also dropped from the pass-through features; otherwise the raw periodic
    # coordinate would leak into the output next to its embedding.
    periodic_set = set()
    for dim in periodic_dims:
        dim = int(dim)
        periodic_set.add(dim + num_features if dim < 0 else dim)

    original_features = [x[i] for i in range(num_features) if i not in periodic_set]

    sin_cos_features = []
    for dim, period in zip(periodic_dims, periods):
        # Angular frequency 2*pi/T maps one period onto one turn of the circle.
        phase = x[dim] * (2 * jnp.pi / period)
        sin_cos_features.extend([jnp.sin(phase), jnp.cos(phase)])

    return jnp.reshape(jnp.array(original_features + sin_cos_features), (-1,))



def mlp_with_periodic_embedding(activation, periodic_dims, periods):
    """
    Build an MLP whose input first passes through a periodic embedding.

    Args:
    activation: Callable applied to the output of every layer except the last.
    periodic_dims: Indices of the input features to embed periodically.
    periods: Periods matching the entries of ``periodic_dims``.

    Returns:
    A function ``model(params, inpt)`` that embeds ``inpt`` with
    ``apply_periodic_embedding`` and feeds the result through the fully
    connected network described by ``params``. The first layer must accept
    the embedded vector; for distinct, non-negative ``periodic_dims`` it has
    ``len(inpt) + len(periodic_dims)`` entries.
    """
    # Reuse the plain MLP forward pass instead of duplicating its layer loop.
    network = mlp(activation)

    def model(params, inpt):
        # Apply the periodic embedding to the specified dimensions with their corresponding periods
        augmented_input = apply_periodic_embedding(inpt, periodic_dims, periods)
        return network(params, augmented_input)
    return model



