#!/usr/bin/env python
# -*- coding: utf-8 -*-

import jax
import jax.numpy as np
from jax import jvp, jit, vmap, grad
from functools import partial

import staix
from staix import elementwise
from initializers import rademacher
import utils
from utils import mean_f, vec_f


##########################
#  function space state  #
##########################


def _Yd_sY(net, params, s_a):
    Yd, sY = jvp(lambda params: net.apply(params, s_a), (params,), (params,))
    return Yd[0], sY[0]


def Yd_sY(net, params, inputs):
    """Batched ``_Yd_sY`` over the leading axis of ``inputs``.

    ``net`` and ``params`` are shared by every example; the result is the
    pair ``(Yd, sY)`` with one entry per input row.
    """
    batched = vmap(_Yd_sY, in_axes=(None, None, 0))
    return batched(net, params, inputs)


def _get_TrH(net, nw, s_a, sumi_Xi):
    """Push the activation second derivatives, weighted by ``sumi_Xi``, to the output.

    Every activation ``sigma_i`` is shifted by an offset (see
    ``get_network_with_placeholder``). The network output at ``s_a`` is then
    differentiated, at zero offsets and with a zero parameter tangent, along
    the offset direction ``sigma_i''(y_i) * sumi_Xi[i-1]``. The pairing uses
    ``zip``, so it only covers as many layers as ``ys[:-1]``, the sigmas and
    ``sumi_Xi`` jointly provide.

    NOTE(review): the name suggests a trace-of-Hessian-like quantity; confirm
    against the derivation this implements.

    Args:
        net: network object (``staix`` serial network).
        nw: network parameters.
        s_a: a single input.
        sumi_Xi: per-layer weights, e.g. a row of ``get_sumi_Xi``.

    Returns:
        The scalar directional derivative of ``net.apply(nw, s_a)[0]``.
    """
    ys = _get_ys(net, nw, s_a)
    second_derivs = [grad(grad(sigma)) for sigma in get_sigmas(net)]

    # Tangent for the offsets: sigma_i''(y_i) elementwise, scaled by sumi_Xi[i-1].
    offset_tangent = [
        np.multiply(vmap(d2sigma)(y), weight)
        for y, d2sigma, weight in zip(ys[:-1], second_derivs, sumi_Xi)
    ]

    zeros_like_tree = utils.transform(["Tree"], "Tree")(lambda x: np.zeros_like(x))
    zero_offsets = zeros_like_tree(offset_tangent)
    zero_nw = zeros_like_tree(nw)

    def output_with_offsets(params, offsets):
        shifted_net = get_network_with_placeholder(net, offsets)
        return shifted_net.apply(params, s_a)[0]

    # Linearize at zero offsets; move only along the offsets, not the parameters.
    _, TrH = jvp(output_with_offsets, (nw, zero_offsets), (zero_nw, offset_tangent))
    return TrH


# Batched over the leading axis of s_a and sumi_Xi; net and nw are shared.
get_TrH = vmap(_get_TrH, in_axes=(None, None, 0, 0))


###################
#  middle layers  #
###################
def _get_ys(net, nw, s_a):
    """Return the odd-indexed intermediate values of ``net`` at ``s_a``.

    ``ys[i-1]`` is ``y_i``, taken from ``net.apply_with_middle(nw, s_a)[1][1::2]``.
    The last entry is then overwritten with the last activation applied to
    the second-to-last entry. The original author describes this as an ad hoc
    change to reduce variance.

    NOTE(review): the item assignment below assumes ``apply_with_middle``
    returns a mutable sequence (e.g. a list) for the intermediates.
    """
    _output, middle = net.apply_with_middle(nw, s_a)
    last_sigma = get_sigmas(net)[-1]
    ys = middle[1::2]
    # Ad hoc change to reduce variance.
    ys[-1] = vmap(last_sigma)(ys[-2])
    return ys


# Batched over the leading axis of s_a; net and nw are shared.
get_ys = vmap(_get_ys, in_axes=(None, None, 0))


def _get_xs_ys(net, nw, s_a):
    _, middle = net.apply_with_middle(nw, s_a)
    # xs[i]=x_i
    xs = middle[0::2]
    ys = middle[1::2]
    return xs, ys


get_xs_ys = vmap(_get_xs_ys, (None, None, 0))


def get_sigmas(net):
    """Return the activation functions of ``net`` as plain callables.

    The odd-indexed entries of ``net.serial_layers()`` are the elementwise
    activation layers; ``set_sigmas`` writes them back to the same slots.
    Each returned callable maps ``y`` to ``layer.apply((), y)``, i.e. the
    layer evaluated with empty parameters.

    Returns:
        List with ``sigmas[i-1] = sigma_i``.
    """

    def _bind(layer):
        # Bind ``layer`` eagerly. A bare ``lambda y: layer.apply((), y)`` in the
        # comprehension would close over the loop variable, so every sigma
        # would call the *last* activation layer.
        return lambda y: layer.apply((), y)

    # sigmas[i-1]=sigma_i
    return [_bind(layer) for layer in net.serial_layers()[1::2]]


def set_sigmas(net, sigmas):
    """Build a new serial network from ``net`` with activations replaced.

    The k-th callable of ``sigmas`` (0-based) is wrapped with ``elementwise``
    and placed at slot ``2*k + 1`` of a copy of ``net.serial_layers()``.
    All remaining layers are reused unchanged, and the list returned by
    ``serial_layers()`` is not mutated.
    """
    new_layers = list(net.serial_layers())
    slot = 1
    for sigma in sigmas:
        new_layers[slot] = elementwise(sigma)
        slot += 2
    return staix.serial(*new_layers)


def get_fc_layers(net):
    """Return the even-indexed entries of ``net.serial_layers()``.

    These are the fully-connected layers: ``fc_layers[i-1]`` is ``c_i W_i``.
    """
    layers = net.serial_layers()
    return layers[::2]


def get_fc_cs(fc_layers, xs):
    """Return the scale ``c_i`` of each fully-connected layer.

    Each layer's ``get_c`` is called with the size of the last axis of the
    matching entry of ``xs``. Layers and inputs are paired with ``zip``, so
    the result is as long as the shorter of the two.

    Returns:
        List with ``cs[i-1] = c_i``.
    """
    cs = []
    for layer, x in zip(fc_layers, xs):
        fan_in = x.shape[-1]
        cs.append(layer.get_c(fan_in))
    return cs


def get_network_with_placeholder(net, placeholder):
    """Return a copy of ``net`` whose activations are shifted by ``placeholder``.

    Activation ``sigma_i`` becomes ``x -> sigma_i(x) + placeholder[i-1]``.
    Sigmas and offsets are paired with ``zip``, so only as many activations
    as there are placeholder entries are passed to ``set_sigmas``.
    ``_get_TrH`` uses this to differentiate the output with respect to the
    offsets.
    """

    def shifted(sigma, offset):
        # Factory, so each closure captures its own (sigma, offset) pair.
        def sigma_plus_offset(x):
            return sigma(x) + offset

        return sigma_plus_offset

    shifted_sigmas = [
        shifted(sigma, offset) for sigma, offset in zip(get_sigmas(net), placeholder)
    ]
    return set_sigmas(net, shifted_sigmas)


######################
#  generalized NTKs  #
######################


def __get_sumi_Xi(net, xs, ys):
    """Compute ``sum_{i=1}^{j} Xi[i, j]`` for every ``j = 1, ..., d-1`` at one input.

    With ``K`` and ``Sigma`` defined below,
    ``Xi[i, j] = Sigma[i-1] * prod_{k=i}^{j-1} K[k-1]``.

    Args:
        net: network exposing ``d`` (used as the layer count below) and the
            layers read by ``get_sigmas`` / ``get_fc_layers``.
        xs: intermediates with ``xs[i] = x_i`` (as returned by ``_get_xs_ys``).
        ys: intermediates with ``ys[i-1] = y_i``.

    Returns:
        Array of length ``d - 1`` with ``sumi_Xi[j-1] = sum_{i=1}^j Xi[i, j]``.
    """
    d = net.d
    # sigmas[i-1]=sigma_i
    sigmas = get_sigmas(net)
    # dsigmas[i-1]=sigma_i'
    dsigmas = [grad(sigma) for sigma in sigmas]
    # fc_layers[i-1]=c_i W_i
    fc_layers = get_fc_layers(net)
    # cs[i-1]=c_i (computed from the fan-in of xs[i-1] = x_{i-1})
    cs = get_fc_cs(fc_layers, xs)

    # Sum (not mean) of f over the leading axis of y.
    # NOTE(review): assumes each per-example y_k / x_k is 1-D, so vmap maps
    # over units; the c^2 factors presumably supply the 1/width normalization
    # that turns these sums into the expectations named below — confirm.
    empirical_sum = lambda f, y: vmap(f)(y).sum()
    # K[k-1] = E[sigma_k'(y_k(s_a)) sigma_k'(y_k(s_a))]
    # Code: K[k-1] = c_{k+1}^2 * sum(sigma_k'(y_k)^2), for k = 1, ..., d-2.
    # The lambda is consumed immediately, so closing over i here is safe.
    K = [
        cs[i] ** 2 * empirical_sum(lambda y: dsigmas[i - 1](y) ** 2, ys[i - 1])
        for i in range(1, d - 1)
    ]
    # Sigma[i-1] = c_i^2 E[x_{i-1}(s_a)^T x_{i-1}(s_a)]
    # Code: Sigma[i-1] = c_i^2 * sum(x_{i-1}^2), for i = 1, ..., d-1.
    Sigma = [
        cs[i - 1] ** 2 * empirical_sum(lambda x: x ** 2, xs[i - 1]) for i in range(1, d)
    ]
    # sumi_Xi[j-1] = sum_{i=1}^j Xi[i,j]
    # with Xi[i,j] = Sigma[i-1] * K[i-1] * K[i] * ... * K[j-2]
    # (the product over k is empty when i == j).
    sumi_Xi = []
    for j in range(1, d):
        Xi_ijs = []
        for i in range(1, j + 1):
            Xi_ij = np.prod(np.array([K[k - 1] for k in range(i, j)] + [Sigma[i - 1]]))
            Xi_ijs.append(Xi_ij)
        sumi_Xi.append(np.sum(np.array(Xi_ijs)))
    return np.array(sumi_Xi)


def _get_sumi_Xi(net, nw, s_a):
    """Evaluate ``__get_sumi_Xi`` on the intermediates of ``net`` at ``s_a``.

    The intermediates come from ``_get_xs_ys`` (``xs[i] = x_i``,
    ``ys[i-1] = y_i``).
    """
    xs, ys = _get_xs_ys(net, nw, s_a)
    result = __get_sumi_Xi(net, xs, ys)
    return result


# Batched over the leading axis of s_a; net and nw are shared.
get_sumi_Xi = vmap(_get_sumi_Xi, in_axes=(None, None, 0))


def get_sumi_Xi_avg(net, subkeys, inputs, **kwargs):
    """Aggregate ``get_sumi_Xi`` over independent initializations of ``net``.

    For each key, parameters are drawn as ``net.init(key, inputs.shape)[1]``
    and ``get_sumi_Xi`` is evaluated on ``inputs``. The per-key results are
    combined by ``mean_f`` (presumably an average over keys), with
    ``kwargs`` forwarded to it.
    """

    def sumi_Xi_for_key(key):
        params = net.init(key, inputs.shape)[1]
        return get_sumi_Xi(net, params, inputs)

    return mean_f(sumi_Xi_for_key, subkeys, **kwargs)


#############################
#  NTKs by empirical value  #
#############################

# Compute these kernels empirically (with autodiff) instead of from their
# analytic form: the closed-form formulas are too tangled, and forward-mode
# propagation through the network is cheap.
def _empirical_sumi_Phi_Xi_ab(net, nw, sumi_Xi, s_a, s_b):
    """JVP of ``_get_TrH`` at ``s_a`` along the parameter gradient of the output at ``s_b``.

    Let ``f(s) = net.apply(nw, s)[0]``. This returns the derivative of
    ``_get_TrH(net, nw, s_a, sumi_Xi)`` w.r.t. ``nw`` in the direction
    ``grad_nw f(s_b)``.
    """

    def output_at_b(params):
        return net.apply(params, s_b)[0]

    def trh_at_a(params):
        return _get_TrH(net, params, s_a, sumi_Xi)

    grad_b = grad(output_at_b)(nw)
    return jvp(trh_at_a, (nw,), (grad_b,))[1]


def empirical_sumi_Phi_Xi(net, nw, sumi_Xi, inputs):
    """Matrix of ``_empirical_sumi_Phi_Xi_ab`` over all pairs of ``inputs``.

    Entry ``[a, b]`` uses ``s_a = inputs[a]`` together with ``sumi_Xi[a]``,
    and ``s_b = inputs[b]``.
    """
    over_b = vmap(_empirical_sumi_Phi_Xi_ab, in_axes=(None, None, None, None, 0))
    over_a_and_b = vmap(over_b, in_axes=(None, None, 0, 0, None))
    return over_a_and_b(net, nw, sumi_Xi, inputs, inputs)


def empirical_sumi_Phi_Xi_avg(net, subkeys, sumi_Xi, inputs, **kwargs):
    """Aggregate ``empirical_sumi_Phi_Xi`` over initializations drawn from ``subkeys``.

    Parameters for each key are ``net.init(key, inputs.shape)[1]``. Results
    are combined by ``mean_f``, which receives ``kwargs``.
    """

    def phi_for_key(key):
        params = net.init(key, inputs.shape)[1]
        return empirical_sumi_Phi_Xi(net, params, sumi_Xi, inputs)

    return mean_f(phi_for_key, subkeys, **kwargs)


def _empirical_sumi_Theta_a_b(net, nw, s_a, s_b):
    """JVP of ``sY`` (from ``_Yd_sY``) at ``s_a`` along the output gradient at ``s_b``.

    The tangent direction is ``grad_nw net.apply(nw, s_b)[0]``.
    """
    grad_b = grad(lambda params: net.apply(params, s_b)[0])(nw)

    def sY_at_a(params):
        _Yd, sY = _Yd_sY(net, params, s_a)
        return sY

    return jvp(sY_at_a, (nw,), (grad_b,))[1]


def empirical_sumi_Theta(net, nw, inputs):
    """Matrix of ``_empirical_sumi_Theta_a_b`` over all pairs of ``inputs``.

    Entry ``[a, b]`` uses ``s_a = inputs[a]`` and ``s_b = inputs[b]``.
    """
    over_b = vmap(_empirical_sumi_Theta_a_b, in_axes=(None, None, None, 0))
    over_a_and_b = vmap(over_b, in_axes=(None, None, 0, None))
    return over_a_and_b(net, nw, inputs, inputs)


def empirical_sumi_Theta_avg(net, subkeys, inputs, **kwargs):
    """Aggregate ``empirical_sumi_Theta`` over initializations drawn from ``subkeys``.

    Parameters for each key are ``net.init(key, inputs.shape)[1]``. Results
    are combined by ``mean_f``, which receives ``kwargs``.
    """

    def theta_for_key(key):
        params = net.init(key, inputs.shape)[1]
        return empirical_sumi_Theta(net, params, inputs)

    return mean_f(theta_for_key, subkeys, **kwargs)


def _empirical_Thetad_a_b(net, nw, s_a, s_b):
    g_b = grad(lambda nw: net.apply(nw, s_b)[0])(nw)
    _, Thetad_a_b = jvp(lambda nw: net.apply(nw, s_a)[0], (nw,), (g_b,))
    return Thetad_a_b


def empirical_Thetad(net, nw, inputs):
    """Empirical kernel matrix of ``_empirical_Thetad_a_b`` over all input pairs.

    Entry ``[a, b]`` uses ``s_a = inputs[a]`` and ``s_b = inputs[b]``.
    """
    over_b = vmap(_empirical_Thetad_a_b, in_axes=(None, None, None, 0))
    over_a_and_b = vmap(over_b, in_axes=(None, None, 0, None))
    return over_a_and_b(net, nw, inputs, inputs)


def empirical_Thetad_avg(net, subkeys, inputs, **kwargs):
    """Aggregate ``empirical_Thetad`` over initializations drawn from ``subkeys``.

    Parameters for each key are ``net.init(key, inputs.shape)[1]``. Results
    are combined by ``mean_f``, which receives ``kwargs``.
    """

    def kernel_for_key(key):
        params = net.init(key, inputs.shape)[1]
        return empirical_Thetad(net, params, inputs)

    return mean_f(kernel_for_key, subkeys, **kwargs)


def _empirical_Thetad_a_a(net, nw, s_a):
    g = grad(lambda nw: net.apply(nw, s_a)[0])(nw)
    return sum([np.square(leaf).sum() for leaf in jax.tree_util.tree_flatten(g)[0]])


def empirical_diag_Thetad(net, nw, inputs, **kwargs):
    """Diagonal of the empirical kernel: ``_empirical_Thetad_a_a`` for each input.

    The loop over ``inputs`` is driven by ``vec_f``, which receives ``kwargs``.
    """

    def diag_entry(s_a):
        return _empirical_Thetad_a_a(net, nw, s_a)

    return vec_f(diag_entry, inputs, **kwargs)


def empirical_diag_Thetad_avg(net, subkeys, inputs, inner_loop_args=None, **kwargs):
    """Aggregate ``empirical_diag_Thetad`` over initializations drawn from ``subkeys``.

    Args:
        net: network; parameters for each key are ``net.init(key, inputs.shape)[1]``.
        subkeys: random keys, one per initialization.
        inputs: batch of inputs whose kernel diagonal is computed.
        inner_loop_args: keyword arguments forwarded to
            ``empirical_diag_Thetad`` (and from there to ``vec_f``).
            ``None`` (the default) means no extra arguments.
        **kwargs: forwarded to ``mean_f`` (the outer loop over keys).

    Returns:
        Whatever ``mean_f`` returns for the per-key diagonals.
    """
    # A ``None`` sentinel avoids a mutable ``{}`` default shared across calls.
    if inner_loop_args is None:
        inner_loop_args = {}

    def diag_for_key(key):
        params = net.init(key, inputs.shape)[1]
        return empirical_diag_Thetad(net, params, inputs, **inner_loop_args)

    return mean_f(diag_for_key, subkeys, **kwargs)


###################################
#  stochastic Hessian estimation  #
###################################


def stochastic_H_est(f, theta, z):
    """Return the Hessian quadratic form ``z^T H z`` of ``f`` at ``theta``.

    ``H`` is the Hessian of the scalar function ``f`` at ``theta``. It is
    never materialized; two nested forward-mode JVPs along ``z`` are used
    instead. For the estimate averaged over random probes to target the
    trace of ``H``, the probes must satisfy ``E[z z^T] = I``.
    """

    def directional_derivative(th):
        _, df_dz = jvp(f, (th,), (z,))
        return df_dz

    _, curvature = jvp(directional_derivative, (theta,), (z,))
    return curvature


def Hutchinson_ests(net):
    """Build a Hutchinson-style estimator for ``net``'s output curvature.

    Returns ``inner(params, inputs, subkeys, **kwargs)``. For each key in
    ``subkeys``, ``inner`` draws a Rademacher-initialized probe with
    ``net.init_any`` (presumably a parameter-shaped tree — confirm). It then
    evaluates ``stochastic_H_est`` of ``p -> net.apply(p, s_a)[0]`` at
    ``params`` for every ``s_a`` in ``inputs``. The loop over keys is driven
    by ``vec_f`` (``kwargs`` forwarded).

    ``inner`` returns the mean over keys and its standard error, i.e. the
    std over keys divided by the square root of the number of keys.
    """

    def inner(params, inputs, subkeys, **kwargs):
        def probe_one(s_a, key):
            probe = net.init_any(key, s_a.shape, rademacher(dtype=np.float_))[1]
            output = lambda p: net.apply(p, s_a)[0]
            return stochastic_H_est(output, params, probe)

        def probe_all_inputs(key):
            return vmap(probe_one, (0, None))(inputs, key)

        results = vec_f(probe_all_inputs, subkeys, **kwargs)
        n_keys = subkeys.shape[0]
        return results.mean(0), results.std(0) / np.sqrt(n_keys)

    return inner
