import os
import numpy as np
import mpmath as mp
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.special import erf
import argparse

# Directory that contains this script (symlinks resolved by realpath).
CDIR = os.path.dirname(os.path.realpath(__file__))
# Output directory next to the script; get_all_initializations() keeps its
# HDF5 cache of critical initializations here.
RESULTS = os.path.join(CDIR, 'result')
# Runs at import time; exist_ok=True makes repeated imports harmless.
os.makedirs(RESULTS, exist_ok=True)

# Element-wise wrappers around mpmath functions. np.frompyfunc(fn, 1, 1) turns
# a one-argument Python callable into a ufunc, so the same wrapper accepts
# scalars and arrays; results are Python objects (object dtype for arrays).
np_exp = np.frompyfunc(mp.exp, 1, 1)
np_log = np.frompyfunc(mp.log, 1, 1)
np_sech = np.frompyfunc(mp.sech, 1, 1)
np_tanh = np.frompyfunc(mp.tanh, 1, 1)
np_arctan = np.frompyfunc(mp.atan, 1, 1)



# --------------- Swish --------------------

def cswish(a, z):
    """Swish-type activation ``z / (1 + exp(-z / a))`` with scale ``a``."""
    gate_denominator = 1 + np_exp(-z / a)
    return z / gate_denominator


def dcswish(a, z):
    """First derivative of :func:`cswish` with respect to ``z``.

    With ``s = 1 / (1 + exp(-z / a))`` the derivative is
    ``s + z * s**2 * exp(-z / a) / a``. This is algebraically the same as
    ``cswish * (1/z + cswish * exp(-z/a) / (a*z))`` but contains no division
    by ``z``, so it is well defined at ``z == 0`` (where it equals 1/2).

    Args:
        a: scale of the sigmoid gate (must be non-zero).
        z: pre-activation; scalar or array.

    Returns:
        The derivative evaluated element-wise (mpmath numbers, or an object
        array of them for array input).
    """
    decay = np_exp(-(z / a))
    gate = 1 / (1 + decay)
    # gate**2 * decay == gate * (1 - gate), written without the subtraction.
    return gate + z * gate ** 2 * decay / a


def d2cswish(a, z):
    """Second derivative of :func:`cswish` with respect to ``z``.

    Closed form ``sech(u)**2 * (2*a - z*tanh(u)) / (4*a**2)`` with
    ``u = z / (2*a)``.
    """
    sech_sq = np_sech(z / (2 * a)) ** 2
    bracket = 2 * a - z * np_tanh(z / (2 * a))
    return (sech_sq * bracket) / (4 * a ** 2)


# --------------- Gelu ---------------------

def cgelu(a, z):
    """GELU-type activation ``z/2 * (1 + erf(z / (a*sqrt(2))))`` with scale ``a``."""
    erf_argument = z / a / np.sqrt(2)
    return z / 2.0 * (1 + erf(erf_argument))


def dcgelu(a, z):
    """First derivative of :func:`cgelu` with respect to ``z``.

    Sum of a Gaussian-density term ``z * exp(-z**2 / (2*a**2)) / (a*sqrt(2*pi))``
    and the gate ``(1 + erf(z / (a*sqrt(2)))) / 2``.
    """
    density_term = z / a / np.sqrt(2.0 * np.pi) * np_exp(-z ** 2 / 2.0 / a ** 2)
    gate_term = 0.5 * (1 + (erf(z / a / np.sqrt(2))))
    return density_term + gate_term


def d2cgelu(a, z):
    """Second derivative of :func:`cgelu` with respect to ``z``.

    Closed form ``exp(-z**2 / (2*a**2)) * (2*a**2 - z**2) / (sqrt(2*pi) * a**3)``.
    """
    gaussian = np_exp(-(z ** 2 / (2 * a ** 2)))
    numerator = gaussian * (2 * a ** 2 - z ** 2)
    denominator = np.sqrt(2.0 * np.pi) * a ** 3
    return numerator / denominator


# --------------- Gumbel --------------------

def cgumbel(a, z):
    """Gumbel-gated activation ``z * exp(-exp(-z / a))`` with scale ``a``."""
    gate = np_exp(-np_exp(-z / a))
    return z * gate


def dcgumbel(a, z):
    """First derivative of :func:`cgumbel` with respect to ``z``.

    Gate ``exp(-exp(-z/a))`` plus ``z/a`` times the gate's own slope factor
    ``exp(-exp(-z/a) - z/a)``.
    """
    gate = np_exp(-np_exp(-(z / a)))
    slope = np_exp(-np_exp(-(z / a)) - z / a) * z / a
    return gate + slope


def d2cgumbel(a, z):
    """Second derivative of :func:`cgumbel` with respect to ``z``.

    Closed form
    ``exp(-exp(-z/a) - 2*z/a) * (exp(z/a) * (2*a - z) + z) / a**2``.
    """
    envelope = np_exp(-np_exp(-(z / a)) - (2 * z) / a)
    bracket = np_exp(z / a) * (2 * a - z) + z
    return (envelope * bracket) / a ** 2


# --------------- Gudermannian --------------------

def cguderman(T, z):
    """Activation ``z * (1/2 + (2/pi) * atan(tanh(z / T)))`` with scale ``T``."""
    gate = 1 / 2 + 2 / np.pi * np_arctan(np_tanh(z / T))
    return z * gate


def dcguderman(T, z):
    """First derivative of :func:`cguderman` with respect to ``z``.

    Gate ``1/2 + (2/pi) * atan(tanh(z/T))`` plus ``z`` times the gate's slope
    ``2 / (pi * T * cosh(z/T)**2 * (1 + tanh(z/T)**2))``.
    """
    tanh_u = np_tanh(z / T)
    cosh_sq = np.cosh(z / T) ** 2
    gate = 1 / 2 + 2 * np_arctan(tanh_u) / np.pi
    slope_term = (2 * z * 1 / cosh_sq) / np.pi / T / (1 + tanh_u ** 2)
    return gate + slope_term


def d2cguderman(T, z):
    """Second derivative of :func:`cguderman` with respect to ``z``.

    With ``t = tanh(z/T)`` and ``s = 1 / cosh(z/T)**2`` the closed form is
    ``4*s * (T - z*(1 + s)*t + T*t**2 - z*t**3) / (T**2 * pi * (1 + t**2)**2)``.
    """
    tanh_u = np_tanh(z / T)
    cosh_sq = np.cosh(z / T) ** 2
    bracket = T - z * (1 + 1 / cosh_sq) * tanh_u + T * tanh_u ** 2 - z * tanh_u ** 3
    numerator = 4 * 1 / cosh_sq * bracket
    denominator = T ** 2 * np.pi * (1 + tanh_u ** 2) ** 2
    return numerator / denominator


# --------------- Algebraic ---------------------

def calgebraic(T, z):
    """Activation ``z * (1 + u / sqrt(1 + u**2)) / 2`` with ``u = z / T``."""
    algebraic_sigmoid = (z / T) / np.sqrt(1 + (z / T) ** 2)
    gate = 0.5 * (1 + algebraic_sigmoid)
    return z * gate


def dcalgebraic(T, z):
    """First derivative of :func:`calgebraic` with respect to ``z``.

    Closed form
    ``1/2 + T*sqrt(1 + z**2/T**2) * (2*T**2*z + z**3) / (2*(T**2 + z**2)**2)``.
    """
    root = np.sqrt(1 + z ** 2 / T ** 2)
    numerator = T * root * (2 * T ** 2 * z + z ** 3)
    denominator = 2 * (T ** 2 + z ** 2) ** 2
    return 0.5 + numerator / denominator


def d2calgebraic(T, z):
    """Second derivative of :func:`calgebraic` with respect to ``z``.

    Closed form
    ``(2*T**3 - T*z**2) / (2*(T**2 + z**2)**2 * sqrt(1 + z**2/T**2))``.
    """
    numerator = 2 * T ** 3 - T * z ** 2
    denominator = 2 * (T ** 2 + z ** 2) ** 2 * np.sqrt(1 + z ** 2 / T ** 2)
    return numerator / denominator


# --------------- Exponential ---------------------

def exponential(T, z):
    """Exponential activation ``exp(z / T)``."""
    scaled = z / T
    return np_exp(scaled)


def dexponential(T, z):
    """First derivative of :func:`exponential`: ``exp(z / T) / T``."""
    growth = np_exp(z / T)
    return growth / T


def d2exponential(T, z):
    """Second derivative of :func:`exponential`: ``exp(z / T) / T**2``."""
    growth = np_exp(z / T)
    return growth / T ** 2


# --------------- ELU ---------------------

def elu(T, z):
    """ELU-type activation: ``z / T`` where ``z > 0``, else ``exp(z / T) - 1``.

    The two branches are blended with a 0/1 mask instead of an ``if`` so the
    same expression works for scalars and arrays.
    """
    is_positive = z > 0
    linear_branch = z / T * is_positive
    saturating_branch = (np_exp(z / T) - 1) * (1 - is_positive)
    return linear_branch + saturating_branch


def delu(T, z):
    """First derivative of :func:`elu`: ``1 / T`` where ``z > 0``, else ``exp(z / T) / T``.

    Uses the same 0/1 mask blending as :func:`elu`.
    """
    is_positive = z > 0
    linear_slope = 1 / T * is_positive
    saturating_slope = (np_exp(z / T) / T) * (1 - is_positive)
    return linear_slope + saturating_slope


def d2elu(T, z):
    """Second derivative of :func:`elu`: ``0`` where ``z > 0``, else ``exp(z / T) / T**2``."""
    is_positive = z > 0
    curvature = np_exp(z / T) / T ** 2
    # The mask zeroes the linear (z > 0) side.
    return curvature * (1 - is_positive)


# --------------- Softplus ---------------------

def softplus(T, z):
    """Softplus activation ``T * log(exp(z / T) + 1)``."""
    log_argument = np_exp(z / T) + 1
    return T * np_log(log_argument)


def dsoftplus(T, z):
    """First derivative of :func:`softplus`: ``exp(z/T) / (exp(z/T) + 1)``."""
    growth = np_exp(z / T)
    return growth / (growth + 1)


def d2softplus(T, z):
    """Second derivative of :func:`softplus`: ``exp(z/T) / (exp(z/T) + 1)**2 / T``."""
    growth = np_exp(z / T)
    return growth / (growth + 1) ** 2 / T


# --------------- Mish ---------------------

def mish(T, z):
    """Mish-type activation ``z * tanh(log(exp(z / T) + 1))``."""
    soft = np_log(np_exp(z / T) + 1)
    return z * np_tanh(soft)


def dmish(T, z):
    """First derivative of :func:`mish` with respect to ``z``.

    With ``x = z / T`` the closed form is ``exp(x) * omega / delta**2`` where
    ``omega = 4*exp(2x) + exp(3x) + 4*(1 + x) + exp(x)*(6 + 4x)`` and
    ``delta = 2 + 2*exp(x) + exp(2x)``.
    """
    e1 = np_exp(z / T)
    e2 = np_exp(2 * z / T)
    e3 = np_exp(3 * z / T)
    omega = 4 * e2 + e3 + 4 * (1 + z / T) + e1 * (6 + 4 * z / T)
    delta = 2 + 2 * e1 + e2
    return (e1 * omega) / delta ** 2


def d2mish(T, z):
    """Second derivative of :func:`mish` with respect to ``z`` (closed form).

    With ``x = z / T`` and ``delta = 2 + 2*exp(x) + exp(2x)`` it evaluates
    ``-4*exp(x) * (3*exp(2x)*(x - 2) + 2*exp(3x)*(x - 1) - 2*(2 + x)
    - 2*exp(x)*(4 + x)) / delta**3 / T``.
    """
    e1 = np_exp(z / T)
    e2 = np_exp(2 * z / T)
    e3 = np_exp(3 * z / T)
    bracket = (3 * e2 * (-2 + z / T) + 2 * e3 * (-1 + z / T) - 2 * (2 + z / T)
               - 2 * e1 * (4 + z / T))
    delta = 2 + 2 * e1 + e2
    return -4 * e1 * bracket / delta ** 3 / T


# ---------------------------------------------------------------------------------------------------

# Registry of the activations processed by this script. Each entry maps an
# activation name to its function ('f'), first derivative ('df') and second
# derivative ('d2f'); all three take (scale, z).
acts_derivatives = dict(
    gelu=dict(f=cgelu, df=dcgelu, d2f=d2cgelu),
    swish=dict(f=cswish, df=dcswish, d2f=d2cswish),
    gumbellu=dict(f=cgumbel, df=dcgumbel, d2f=d2cgumbel),
    gudermanlu=dict(f=cguderman, df=dcguderman, d2f=d2cguderman),
    algebraiclu=dict(f=calgebraic, df=dcalgebraic, d2f=d2calgebraic),
    mish=dict(f=mish, df=dmish, d2f=d2mish),
    # Currently disabled entries:
    # softplus=dict(f=softplus, df=dsoftplus, d2f=d2softplus),
    # elu=dict(f=elu, df=delu, d2f=d2elu),
    # exp=dict(f=exponential, df=dexponential, d2f=d2exponential),
)


def gK(K, a, f):
    """Return the Gaussian average of ``f(a, z)**2`` for ``z ~ N(0, K)``.

    Numerically integrates ``exp(-z**2 / (2K)) * f(a, z)**2 / sqrt(2*pi*K)``
    over the whole real line with ``scipy.integrate.quad``.

    Args:
        K: variance of the Gaussian weight (must be positive).
        a: scale parameter forwarded to ``f``.
        f: callable ``f(a, z)``.

    Returns:
        The value of the integral (quad's error estimate is discarded).
    """
    def integrand(z):
        return np_exp(-z ** 2 / 2 / K) * f(a, z) ** 2 / np.sqrt(2.0 * np.pi * K)

    # np.inf rather than the alias np.infty, which NumPy 2.0 removed.
    value, _abs_error = quad(integrand, -np.inf, np.inf, limit=10000, epsabs=1.0e-14, epsrel=1.0e-14)
    return value


def Cw(K, df):
    """Return ``1 / E[df(1, z)**2]`` for ``z ~ N(0, K)``.

    The expectation is the integral of
    ``exp(-z**2 / (2K)) * df(1, z)**2 / sqrt(2*pi*K)`` over the real line,
    evaluated with ``scipy.integrate.quad``; the scale argument of ``df`` is
    fixed to 1.

    Args:
        K: variance of the Gaussian weight (must be positive).
        df: callable ``df(scale, z)``, the first derivative of an activation.

    Returns:
        The reciprocal of the integral.
    """
    def integrand(z):
        return np_exp(-z ** 2 / 2 / K) * df(1, z) ** 2 / np.sqrt(2.0 * np.pi * K)

    # np.inf rather than the alias np.infty, which NumPy 2.0 removed.
    value, _abs_error = quad(integrand, -np.inf, np.inf, limit=10000, epsabs=1.0e-14, epsrel=1.0e-14)
    return 1.0 / value


def Cb(K, Cw, f):
    """Return ``K - Cw * gK(K, 1, f)``.

    Note that the parameter ``Cw`` is a number here and shadows the
    module-level function of the same name inside this body.
    """
    return K - gK(K, 1, f) * Cw


def fintK(K, f, d2f):
    """Return the Gaussian average of ``f(1, z) * d2f(1, z)`` for ``z ~ N(0, K)``.

    Numerically integrates
    ``exp(-z**2 / (2K)) * f(1, z) * d2f(1, z) / sqrt(2*pi*K)`` over the whole
    real line with ``scipy.integrate.quad``; both callables get scale 1.

    Args:
        K: variance of the Gaussian weight (must be positive).
        f: callable ``f(scale, z)``, an activation.
        d2f: callable ``d2f(scale, z)``, its second derivative.

    Returns:
        The value of the integral (quad's error estimate is discarded).
    """
    def integrand(z):
        return np_exp(-z ** 2 / 2 / K) * f(1, z) * d2f(1, z) / np.sqrt(2.0 * np.pi * K)

    # np.inf rather than the alias np.infty, which NumPy 2.0 removed.
    value, _abs_error = quad(integrand, -np.inf, np.inf, limit=10000, epsabs=1.0e-14, epsrel=1.0e-14)
    return value


def findK(f, d2f):
    """Return the root in ``K`` of ``fintK(K, f, d2f)`` on ``[0.001, 1000]``.

    The root is located with Brent's method (``scipy.optimize.brentq``), which
    requires the function to change sign across the bracket.
    """
    def objective(k):
        return fintK(k, f, d2f)

    return brentq(objective, 0.001, 1000.0)


def layer_update(k, Cb, Cw, a, f):
    """Return ``Cw * gK(k*a**2, a, f) + Cb*a**2 - k*a**2``.

    ``Cb`` and ``Cw`` are numbers here; they shadow the module-level functions
    of the same names inside this body.
    """
    weight_term = Cw * gK(k * a ** 2, a, f)
    bias_term = Cb * a ** 2
    return weight_term + bias_term - k * a ** 2


def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the previous behaviour.

    Returns:
        argparse.Namespace with ``act``, ``cw``, ``lowlim``, ``uplim`` and
        ``npoints``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--act', default='all', type=str, help='activation name')
    parser.add_argument('--cw', default=-1, type=float, help='Cw')
    parser.add_argument('--lowlim', default=0.1, type=float, help='low limit for krange')
    parser.add_argument('--uplim', default=1.5, type=float, help='upper limit for krange')
    # npoints is used as the sample count of np.linspace, which only accepts
    # integers; parsing it as float made any explicit --npoints value fail.
    parser.add_argument('--npoints', default=551, type=int, help='number of points for krange')
    return parser.parse_args(argv)


def main(args):
    """Compute and save the layer-update curve for one activation.

    Steps: look up the activation ``args.act`` in ``acts_derivatives``, find
    ``ksol`` with ``findK``, take ``cw`` from ``args.cw`` (or compute it with
    ``Cw`` when ``args.cw`` is negative), compute ``cb`` with ``Cb``, evaluate
    ``layer_update`` at scale 1 on ``args.npoints`` values between
    ``args.lowlim * ksol`` and ``args.uplim * ksol``, and store everything
    with ``np.savez`` under the relative name ``'c' + act + '_params_v2'``.

    Args:
        args: namespace with ``act``, ``cw``, ``lowlim``, ``uplim``, ``npoints``.

    Raises:
        KeyError: if ``args.act`` is not a key of ``acts_derivatives``.
    """
    act = args.act
    functions = acts_derivatives[act]
    f, df, d2f = functions['f'], functions['df'], functions['d2f']
    ksol = findK(f, d2f)

    cw = args.cw
    if cw < 0:
        # A negative --cw means "not given": derive it from the activation.
        cw = Cw(ksol, df)
    cb = Cb(ksol, cw, f)
    print('parameters found')

    # np.linspace only accepts an integer sample count; npoints may arrive as
    # a float (e.g. 100.0) from the command line.
    klist = np.linspace(args.lowlim * ksol, args.uplim * ksol, int(args.npoints))
    T = 1.0
    lu1 = [layer_update(k, cb, cw, T, f) for k in klist]
    print("T=1 complete")
    print(ksol, cb, cw)
    np.savez('c' + act + "_params_v2", krange=klist, cw=cw, cb=cb, kstar=ksol, lu1=lu1)


def get_all_initializations():
    """Return a table of ``cw``, ``cb`` and ``kstar`` for every registered activation.

    The table is cached in ``RESULTS/critical_inits.h5`` under the key
    ``'df'``: when the file exists it is simply loaded; otherwise the values
    are computed with ``findK``, ``Cw`` and ``Cb`` for each entry of
    ``acts_derivatives`` and the file is written.

    Returns:
        pandas.DataFrame with columns ``act``, ``cw``, ``cb``, ``kstar``.
    """
    import pandas as pd

    inits_path = os.path.join(RESULTS, 'critical_inits.h5')
    if os.path.exists(inits_path):
        return pd.read_hdf(inits_path, 'df')

    columns = ['act', 'cw', 'cb', 'kstar']
    frames = []
    for act, functions in acts_derivatives.items():
        print(act)
        f, df, d2f = functions['f'], functions['df'], functions['d2f']
        ksol = findK(f, d2f)
        cw = Cw(ksol, df)
        cb = Cb(ksol, cw, f)
        frames.append(pd.DataFrame([{'act': act, 'cw': cw, 'cb': cb, 'kstar': ksol}], columns=columns))

    # DataFrame.append was removed in pandas 2.0; concatenating the one-row
    # frames yields the same table (every row keeps index 0, as before).
    data = pd.concat(frames) if frames else pd.DataFrame(columns=columns)
    data.to_hdf(inits_path, key='df', mode='w')
    return data


def plot_acts():
    """Plot every registered activation with its first and second derivative.

    One panel per entry of ``acts_derivatives``, evaluated at scale 1 on 100
    points in [-10, 10]: function in red, first derivative in green, second
    derivative in blue.
    """
    import matplotlib.pyplot as plt

    n_panels = len(acts_derivatives.keys())
    fig, axs = plt.subplots(1, n_panels, gridspec_kw={'wspace': .0, 'hspace': 0.}, figsize=(4, 4))

    x = np.linspace(-10, 10, 100)
    curve_colors = (('f', 'r'), ('df', 'g'), ('d2f', 'b'))
    for i, (name, fs) in enumerate(acts_derivatives.items()):
        print(name)
        panel = axs[i]
        for key, color in curve_colors:
            panel.plot(x, fs[key](1., x), color=color)
        panel.set_title(name)

    plt.show()


if __name__ == '__main__':
    # Entry point: '--act all' builds or loads the table of initializations,
    # '--act plot' draws the activations, any other value goes to main().
    args = parse_args()

    argless_modes = {'all': get_all_initializations, 'plot': plot_acts}
    if args.act in argless_modes:
        argless_modes[args.act]()
    else:
        main(args)
