import numpy as np
import jax.numpy as jnp
from jax import nn
from jax import grad, jit, vmap
from jax import random
from functools import partial


def random_layer_params(n_in, n_out, key, bias=False, scale=0.0002):
    """
    Draw Gaussian-initialised parameters for a single dense layer.

    Args:
        n_in (int): Fan-in of the layer (number of rows of the weight matrix).
        n_out (int): Fan-out of the layer (number of columns of the weight matrix).
        key (jax.random.PRNGKey): PRNG key used for sampling.
        bias (bool, optional): If True, also sample a bias vector. Defaults to False.
        scale (float, optional): Multiplier applied to standard-normal draws, i.e. the
            standard deviation of every entry. Defaults to 0.0002.

    Returns:
        Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: The ``(n_in, n_out)``
        weight matrix on its own, or a ``(weights, bias)`` pair with a bias of shape
        ``(n_out,)`` when ``bias`` is True.
    """
    if bias:
        # Split the key so the weights and the bias use independent random streams.
        weight_key, bias_key = random.split(key)
        weights = scale * random.normal(weight_key, (n_in, n_out))
        offsets = scale * random.normal(bias_key, (n_out,))
        return weights, offsets
    return scale * random.normal(key, (n_in, n_out))

def init_network_params(sizes, key, bias=False, scale=0.0002):
    """
    Initialize the parameters of a multi-layer linear network.

    Args:
        sizes (List[Union[int, str]]): Layer widths ``[n_in, h_1, ..., n_out]``.
            Passing ``'auto'`` as the second entry (``[n_in, 'auto', n_out]``)
            requests a shallow network: a single layer mapping ``n_in`` directly
            to ``n_out``.
        key (jax.random.PRNGKey): The random key for parameter generation.
        bias (bool, optional): Whether the final layer gets a bias vector.
            Defaults to False. Only the final layer ever gets a bias, because
            ``forward`` and ``update`` only handle a ``(w, b)`` pair in the last
            position.
        scale (float, optional): Standard deviation of the Gaussian initialization.
            Defaults to 0.0002.

    Returns:
        List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]]: One entry per
        layer. Every entry is a weight matrix, except that the last entry is a
        ``(weights, bias)`` pair when ``bias`` is True.
    """
    # Shallow network: a single layer from input straight to output.
    if sizes[1] == 'auto':
        return [random_layer_params(sizes[0], sizes[2], key, bias, scale)]

    # len(sizes) keys are drawn but only the first len(sizes) - 1 are used (zip
    # stops at the shorter sequence). Kept as-is so seeded runs reproduce.
    keys = random.split(key, len(sizes))
    last = len(sizes) - 2
    # Hidden layers carry no bias (presumably the inputs already include a
    # constant bias feature -- confirm with callers). Only the final layer may get
    # one, matching what `forward` and `update` expect.
    return [
        random_layer_params(m, n, k, bias and i == last, scale)
        for i, (m, n, k) in enumerate(zip(sizes[:-1], sizes[1:], keys))
    ]

def forward(params, input, bias=False):
    """
    Run a single example through a stack of linear layers (no non-linearities).

    Args:
        params (List[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]]): Layer
            parameters. Every entry before the last is a weight matrix. The last
            entry is a weight matrix, or a ``(weights, bias)`` pair when ``bias``
            is True.
        input (jnp.ndarray): One input example. It is right-multiplied by each
            weight matrix in turn.
        bias (bool, optional): Whether the final layer adds a bias term.
            Defaults to False.

    Returns:
        jnp.ndarray: The network output for this example.
    """
    activations = input
    # Hidden layers: plain matrix products, with no bias and no activation function.
    for w in params[:-1]:
        activations = jnp.dot(activations, w)

    if bias:
        final_w = params[-1][0]
        final_b = params[-1][1]
        return jnp.dot(activations, final_w) + final_b
    return jnp.dot(activations, params[-1])

# Batched variant of `forward`: the parameters and the bias flag are shared by
# every example (axis None), while the inputs are mapped over their leading axis (0).
batched_forward = vmap(forward, (None, 0, None))

def loss(params, inputs, targets, bias=False):
    """
    Compute half the sum of squared errors over a batch.

    Returns ``0.5 * sum((predictions - targets) ** 2)``. This is a *sum*, not a
    mean, so the loss value and its gradients grow with the batch size.

    Args:
        params (array-like): The network parameters, in the structure `forward` expects.
        inputs (array-like): A batch of inputs, mapped over the leading axis by
            `batched_forward`.
        targets (array-like): The target values; must broadcast against the predictions.
        bias (bool, optional): Whether the final layer carries a bias term.
            Defaults to False.

    Returns:
        jnp.ndarray: The scalar loss value.
    """
    predictions = batched_forward(params, inputs, bias)
    # Use jnp.sum rather than np.sum so the reduction stays explicitly inside JAX
    # under grad/jit, instead of relying on NumPy delegating to the array's .sum.
    return 0.5 * jnp.sum((predictions - targets) ** 2)

@partial(jit, static_argnames=['bias'])
def update(params, x, y, step_size, bias=False):
    """
    Take one plain gradient-descent step on `loss`.

    Args:
        params: The model parameters. Every entry is a weight matrix, except that
            the last entry is a ``(w, b)`` pair when ``bias`` is True.
        x: A batch of input data.
        y: A batch of target values.
        step_size: Learning rate that multiplies the gradient.
        bias: Whether the final layer carries a bias term. It is a static argument
            under ``jit`` because it changes the parameter structure, so each
            distinct value is compiled separately.

    Returns:
        list: New parameters with the same structure as ``params``, each updated
        as ``p - step_size * dp``.
    """
    # grad differentiates w.r.t. the first argument, so `grads` mirrors `params`.
    grads = grad(loss)(params, x, y, bias)
    last = len(params) - 1

    updated_params = []
    for i, (param, d_param) in enumerate(zip(params, grads)):
        if bias and i == last:
            # The final layer is a (weights, bias) pair; update each part separately.
            w, b = param
            dw, db = d_param
            updated_params.append((w - step_size * dw, b - step_size * db))
        else:
            updated_params.append(param - step_size * d_param)
    return updated_params

