#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import jax
import jax.numpy as jnp
from jax import jit
import numpy as np
from tqdm import tqdm

# Learning rate applied to every gradient-descent step in update_raw.
step_size = 1e-3

# Seed for the JAX PRNG key that initializes the network parameters.
seed = 0
# Number of full-batch gradient steps performed by train_net.
num_epochs = 10000

def random_layer_params(m, n, key, scale=1e-2):
    """Draw random Gaussian parameters for a single dense layer.

    Args:
        m: Number of inputs to the layer (columns of the weight matrix).
        n: Number of outputs of the layer (rows of the weight matrix).
        key: JAX PRNG key; it is split so weights and biases use
            independent streams.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = jax.random.split(key)
    weights = scale * jax.random.normal(weight_key, (n, m))
    biases = scale * jax.random.normal(bias_key, (n,))
    return weights, biases

def init_network_params(sizes, key):
    """Build random parameters for a fully-connected network.

    Args:
        sizes: Sequence of layer widths, input width first. Each pair of
            consecutive entries defines one dense layer.
        key: JAX PRNG key from which the per-layer keys are derived.

    Returns:
        A list with ``len(sizes) - 1`` entries, each the
        ``(weights, biases)`` tuple produced by ``random_layer_params``.
    """
    # One key is generated per entry of `sizes`; zip stops at the shorter
    # sequences, so the final key is simply never consumed.
    layer_keys = jax.random.split(key, len(sizes))
    network = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        network.append(random_layer_params(fan_in, fan_out, layer_key))
    return network


def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, image):
    """Run the network forward on inputs stored one example per column.

    The first layer is a pure linear map: its bias is not applied and no
    activation follows it. Every layer strictly between the first and the
    last is affine followed by ReLU. The last layer is affine with no
    activation.

    Args:
        params: List of ``(weights, biases)`` tuples, one per layer.
        image: 2-D array whose rows are input features and whose columns
            are examples (biases are broadcast across columns).

    Returns:
        2-D array of raw outputs, one column per example.
    """
    first_w, _ = params[0]
    hidden = first_w @ image

    for layer_w, layer_b in params[1:-1]:
        # Biases are 1-D; add a trailing axis so they broadcast per column.
        hidden = relu(layer_w @ hidden + layer_b[:, None])

    out_w, out_b = params[-1]
    return out_w @ hidden + out_b[:, None]

@jit
def get_nll(params, X, y):
    """Return half the mean squared error of the network on ``(X, y)``.

    Args:
        params: List of ``(weights, biases)`` tuples accepted by ``predict``.
        X: 2-D array with one example per row; it is transposed before
            being handed to ``predict``.
        y: Targets. Either broadcast-compatible with the prediction
            array of shape ``(outputs, examples)`` (for instance a 1-D
            vector with one entry per example), or laid out like ``X``
            with one example per row, i.e. ``(examples, outputs)``.

    Returns:
        Scalar ``0.5 * mean((pred - y) ** 2)``.
    """
    pred = predict(params, X.T)
    y = jnp.asarray(y)
    # Predictions are (outputs, examples). A target given as
    # (examples, outputs) -- e.g. an (N, 1) column -- would otherwise
    # broadcast against a (1, N) prediction into an (N, N) residual and
    # silently yield a wrong loss. Shapes are static under jit, so this
    # check is resolved at trace time.
    if y.ndim == 2 and y.shape != pred.shape and y.shape == pred.shape[::-1]:
        y = y.T
    return 0.5 * jnp.mean(jnp.square(pred - y))

# Gradient of get_nll with respect to its first argument (params), which is
# what jax.grad differentiates by default; X and y are passed through.
grad = jax.grad(get_nll)

def update_raw(params, X, y):
    """Take one plain gradient-descent step on the loss.

    Args:
        params: List of ``(weights, biases)`` tuples, one per layer.
        X: Inputs forwarded unchanged to the gradient function.
        y: Targets forwarded unchanged to the gradient function.

    Returns:
        A new list of ``(weights, biases)`` tuples, each moved against its
        gradient by the module-level ``step_size``.
    """
    def descend(value, slope):
        return value - step_size * slope

    layer_grads = grad(params, X, y)
    return [
        (descend(w, dw), descend(b, db))
        for (w, b), (dw, db) in zip(params, layer_grads)
    ]

def train_net(X, y):
    """Standardize the features and fit the network by gradient descent.

    The network has layer widths ``[n_features, 1, 512, 1]`` and is
    initialized from the module-level ``seed``. Training runs for
    ``num_epochs`` full-batch steps.

    Args:
        X: 2-D array-like with one example per row and one feature per
            column. Each column is shifted to zero mean and scaled to
            unit standard deviation before training.
        y: Targets, one per example. Any shape holding exactly one value
            per example is accepted; it is flattened to 1-D.

    Returns:
        A ``(params, costs)`` tuple: the trained list of
        ``(weights, biases)`` tuples, and a NumPy array of length
        ``num_epochs`` holding the loss measured before each update.
    """
    X = np.asarray(X)
    feature_mean = np.mean(X, axis=0)
    feature_std = np.std(X, axis=0)
    # A constant column has zero spread; dividing by it would turn the
    # column into NaN and poison every parameter after the first step.
    # Dividing by 1 instead leaves the centred column at exactly zero.
    feature_std = np.where(feature_std == 0, 1.0, feature_std)
    X = (X - feature_mean[np.newaxis, :]) / feature_std[np.newaxis, :]

    # The network emits a (1, N) prediction. An (N, 1) target would
    # broadcast against it into an (N, N) residual, so flatten to (N,).
    y = np.asarray(y).reshape(-1)

    layer_sizes = [X.shape[1], 1, 512, 1]

    params = init_network_params(layer_sizes, jax.random.PRNGKey(seed))

    update = jax.jit(update_raw)

    costs = np.zeros(num_epochs)
    for i in tqdm(range(num_epochs)):
        # Record the loss of the current parameters, then step.
        costs[i] = get_nll(params, X, y)
        params = update(params, X, y)

    return params, costs
