from jax import random, value_and_grad, jit, vmap
import jax.numpy as jnp
import logging
from utils import vectorize,unvectorize
from optimizers import run_optex, run_standard, run_line_search, run_benchmark, tuning_mattern
import numpy as np
# from optimizers import run_optex
# Module-scoped logger (rather than the root logger) so this module's
# verbosity can be tuned on its own; records still propagate to root handlers.
logger = logging.getLogger(__name__)

def predict(params, X):
    """Run a single (per-example) input through the fully-connected network.

    ``params`` is a sequence of ``(w, b)`` pairs, one per layer.  Every layer
    except the last is followed by ``relu``; the last layer is affine only, so
    the value returned is the raw network output.

    Args:
        params: Per-layer ``(w, b)`` pairs; each layer computes
            ``jnp.dot(w, a) + b`` on the incoming activations ``a``.
        X: One input example (not batched; use ``batch_func`` to batch).

    Returns:
        The output of the final affine layer.
    """
    hidden = X
    for weight, bias in params[:-1]:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    out_weight, out_bias = params[-1]
    return jnp.dot(out_weight, hidden) + out_bias


def batch_func(predict_func):
    """Lift a per-example ``predict_func(params, x)`` into a batched function.

    ``params`` is broadcast unchanged to every example (``in_axes=None``),
    while ``x`` is mapped over its leading axis.
    """
    return vmap(predict_func, in_axes=(None, 0))


def random_layer_params(m, n, key):
    """Create ``(w, b)`` for one dense layer mapping ``m`` inputs to ``n`` outputs.

    The weight is drawn by ``kaiming`` with shape ``(n, m)``; the bias is a
    zero vector of length ``n``.
    """
    # Split exactly as before so the weight key, and therefore the sampled
    # weights, stay reproducible.  The second key is currently unused.
    weight_key, _unused_bias_key = random.split(key)
    weight = kaiming(weight_key, m, n)
    bias = jnp.zeros((n,))
    return weight, bias


def kaiming(key, m, n):
    """Sample an ``(n, m)`` weight matrix with Kaiming/He-normal scaling.

    Entries are standard-normal draws multiplied by ``sqrt(2 / m)``, where
    ``m`` is the layer's fan-in (number of inputs).
    """
    scale = jnp.sqrt(2.0 / m)
    return scale * random.normal(key, shape=(n, m))


def init_network_params(sizes, key):
    """Randomly initialise every layer of a fully-connected network.

    ``sizes`` lists layer widths, input first: layer ``i`` maps ``sizes[i]``
    units to ``sizes[i + 1]`` units via ``random_layer_params``.

    Returns:
        A list of ``(w, b)`` pairs, one per consecutive pair of ``sizes``.
    """
    logger.info(f"Randomly initializing a network with layers {sizes}")
    # One key per entry of ``sizes``; since there is one fewer layer than
    # sizes, ``map`` stops early and the last key is never consumed.
    layer_keys = random.split(key, len(sizes))
    return list(map(random_layer_params, sizes[:-1], sizes[1:], layer_keys))


def relu(x):
    """Rectified linear unit: elementwise ``max(x, 0)`` (NaNs propagate)."""
    return jnp.maximum(x, 0)


def mse_loss(func, params, X, Y, Z):
    """Mean squared error between ``func(params, X, Y)`` and the target ``Z``.

    Returns:
        The mean, over all elements, of ``(func(params, X, Y) - Z) ** 2``.
    """
    residual = func(params, X, Y) - Z
    return jnp.square(residual).mean()


def update(method, opt_name, func, params, X, Y, Z, opt_state, step_size=0.001, params_shape=None, grad_clip=1, num_parall=1, edim=-1, inter_results=None):
    """Run one optimisation step through an ``optimizers.run_<method>`` driver.

    The inputs ``X``, ``Y`` and targets ``Z`` are zipped into a list of
    ``(x, y, z)`` tuples.  That list is handed to the driver together with an
    MSE objective built from ``func`` (see ``mse_loss``).

    Args:
        method: Driver suffix; one of ``"optex"``, ``"standard"``,
            ``"line_search"`` or ``"benchmark"``.
        opt_name: Optax optimizer name; forwarded as the string
            ``"optax." + opt_name``.
        func: Model callable, invoked as ``func(params, X, Y)`` by ``mse_loss``.
        params: Model parameters; flattened with ``vectorize`` before the call.
        X, Y, Z: Inputs and targets, iterated in lockstep (``zip`` truncates
            to the shortest).
        opt_state: Optimizer state from the previous step.
        step_size: Learning rate forwarded to the driver.
        params_shape: Forwarded to the driver, presumably so it can restore
            the parameter structure from the flat vector. TODO confirm.
        grad_clip: Gradient-clipping value forwarded to the driver.
        num_parall: Forwarded to the driver (presumably a parallelism degree;
            TODO confirm).
        edim: Forwarded as ``effective_dim``.
        inter_results: Dict forwarded to the driver.  Pass your own dict, and
            reuse it across calls, to keep whatever the driver records there.
            When omitted, a fresh dict is created for this call.

    Returns:
        ``(loss, x, opt_state)`` exactly as returned by the driver.  ``x`` is
        presumably the updated flat parameter vector; TODO confirm.

    Raises:
        ValueError: If ``method`` does not name a known driver.
    """
    # Fresh dict per call: a ``{}`` default is created once at definition
    # time and would be shared (and keep growing) across every call.
    if inter_results is None:
        inter_results = {}

    # Explicit dispatch instead of ``eval("run_" + method)``: same drivers,
    # but an unexpected ``method`` can no longer evaluate arbitrary code.
    runners = {
        "optex": run_optex,
        "standard": run_standard,
        "line_search": run_line_search,
        "benchmark": run_benchmark,
    }
    try:
        runner = runners[method]
    except KeyError:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(runners)}"
        ) from None

    logger.debug("update: optimizer=%s step_size=%s", opt_name, step_size)

    datas = list(zip(X, Y, Z))

    loss, x, opt_state = runner(
        # Keep the original parameter names in case the driver calls by keyword.
        lambda params, X, Y, Z: mse_loss(func, params, X, Y, Z),
        # NOTE(review): the optimizer is passed as a string.  If the driver
        # resolves it with eval, opt_name must come from trusted config.
        "optax." + opt_name,
        step_size,
        vectorize(params),
        params_shape,
        1,  # positional argument whose meaning is defined by the driver
        num_parall,
        datas,
        opt_state,
        grad_clip,
        effective_dim=edim,
        inter_results=inter_results,
    )
    return loss, x, opt_state


def grad_descent(params, grads, step_size):
    """Apply one plain gradient-descent step to every ``(w, b)`` layer.

    Each parameter is updated as ``p - step_size * dp``.  Layers are paired
    with their gradients positionally; extra entries in either list are
    ignored.

    Returns:
        A new list of ``(w, b)`` tuples; the inputs are not modified.
    """
    updated = []
    for layer, layer_grad in zip(params, grads):
        weight, bias = layer
        d_weight, d_bias = layer_grad
        updated.append((weight - step_size * d_weight, bias - step_size * d_bias))
    return updated
