from math import sqrt, pi
import numpy as np
from scipy.integrate import quad as integrate


def get_kth_hermite_coef(f, k):
    gaussian_pdf = lambda z: np.exp(-z ** 2 / 2) / np.sqrt(2 * pi)

    def integrand(z):
        o = f(z) * gaussian_pdf(z)
        if k == 0:
            pass
        elif k == 1:
            o *= z
        elif k == 2:
            o *= (z ** 2 - 1)
        else:
            raise NotImplementedError
        return o

    return integrate(integrand, -np.inf, np.inf)[0]


def gaussian_norm(f):
    """Return the Gaussian L2 norm of ``f``, i.e. sqrt(E[f(Z)**2]) for Z ~ N(0, 1).

    The second moment is obtained as the zeroth Hermite coefficient of ``f**2``.
    """
    def f_squared(z):
        return f(z) ** 2

    second_moment = get_kth_hermite_coef(f_squared, k=0)
    return sqrt(second_moment)


class ActivationFunction(object):
    def __call__(self, z):
        l0 = get_kth_hermite_coef(self.val, k=0)
        return self.val(z) - l0

    def val(self, z):
        raise NotImplementedError

    def grad(self, z):
        raise NotImplementedError

    @property
    def norm2(self):
        if hasattr(self, "_norm2"):
            return self._norm2
        return gaussian_norm(self) ** 2

    @property
    def norm2_grad(self):
        if hasattr(self, "_norm2_grad"):
            return self._norm2_grad
        return gaussian_norm(self.grad) ** 2

    @property
    def lambda1(self):
        if hasattr(self, "_lambda1"):
            return self._lambda1
        return get_kth_hermite_coef(self, k12)

    @property
    def lambda2(self):
        if hasattr(self, "_lambda2"):
            return self._lambda2
        return get_kth_hermite_coef(self, k=2)

    @property
    def lambda3(self):
        if hasattr(self, "_lambda3"):
            return self._lambda3
        return get_kth_hermite_coef(self.grad, k=2)


class QuadraticActivationFunction(ActivationFunction):
    """Quadratic activation ``z**2 - 1`` (the probabilists' Hermite polynomial He_2).

    All Gaussian statistics are known in closed form, so they are preset
    here rather than integrated numerically.
    """

    def __init__(self):
        # Closed-form Gaussian statistics for Z ~ N(0, 1).
        self._norm2, self._norm2_grad = 2, 4
        self._lambda1, self._lambda2, self._lambda3 = 0, 2, 0.

    def val(self, z):
        squared = z ** 2
        return squared - 1

    def grad(self, z):
        return z * 2


class IdentityActivationFunction(ActivationFunction):
    """Identity (linear) activation ``z``.

    Gaussian statistics are preset to their exact closed-form values for
    Z ~ N(0, 1) instead of being integrated numerically.
    """

    def __init__(self):
        self._lambda0 = 0        # E[Z], the Gaussian mean of val
        self._norm2 = 1          # E[Z^2]
        self._norm2_grad = 1     # grad is identically 1
        self._lambda1 = 1        # E[Z * Z]
        self._lambda2 = 0        # E[Z (Z^2 - 1)]: odd moments vanish
        self._lambda3 = 0        # E[1 * (Z^2 - 1)]

    def val(self, z):
        return z

    def grad(self, z):
        return np.ones_like(z)


class ReLUActivationFunction(ActivationFunction):
    """Rectified linear unit, ``max(z, 0)``.

    The Gaussian statistics are preset to their closed-form values, so none
    of them needs numerical integration.
    """

    def __init__(self):
        tau = 2 * pi
        self._norm2 = (pi - 1) / tau          # Var[max(Z, 0)]
        self._norm2_grad = 1 / 2              # E[1{Z > 0}^2] = P(Z > 0)
        self._lambda1 = 1 / 2                 # E[max(Z, 0) * Z]
        self._lambda2 = 1 / sqrt(tau)         # E[max(Z, 0) * (Z^2 - 1)]
        self._lambda3 = 0                     # E[1{Z > 0} * (Z^2 - 1)]

    def val(self, z):
        rectified = np.maximum(z, 0)
        return rectified

    def grad(self, z):
        # Indicator of the positive half-line, returned as floats.
        return 1. * (z > 0)


class AbsValActivationFunction(ActivationFunction):
    """Absolute-value activation ``|z|``.

    Gaussian statistics (Z ~ N(0, 1)) are preset to their closed forms
    instead of being integrated numerically across the kink at 0.
    """

    def __init__(self):
        mean_abs = sqrt(2 / pi)            # E|Z|
        self._lambda0 = mean_abs           # Gaussian mean of val
        self._norm2 = 1 - 2 / pi           # Var|Z| = E[Z^2] - (E|Z|)^2
        self._norm2_grad = 1               # sign(Z)^2 = 1 almost surely
        self._lambda1 = 0                  # E[|Z| Z] = 0 by symmetry
        self._lambda2 = mean_abs           # E|Z|^3 - E|Z| = 2 E|Z| - E|Z|
        self._lambda3 = 0                  # E[sign(Z) (Z^2 - 1)] = 0 by symmetry

    def val(self, z):
        return np.abs(z)

    def grad(self, z):
        return np.sign(z)


class SineActivationFunction(ActivationFunction):
    """Sine activation ``sin(z)``.

    Gaussian statistics are preset to closed forms derived from
    E[cos(tZ)] = exp(-t^2 / 2) for Z ~ N(0, 1).
    """

    def __init__(self):
        e_half = np.exp(-0.5)                       # E[cos Z]
        self._norm2 = (1 - np.exp(-2)) / 2          # E[sin^2 Z] = (1 - E[cos 2Z]) / 2
        self._norm2_grad = (1 + np.exp(-2)) / 2     # E[cos^2 Z] = (1 + E[cos 2Z]) / 2
        self._lambda1 = e_half                      # E[Z sin Z] = E[cos Z] (Stein's lemma)
        self._lambda2 = 0                           # E[sin Z (Z^2 - 1)]: odd integrand
        self._lambda3 = -e_half                     # E[cos Z (Z^2 - 1)] = E[cos''(Z)] = -E[cos Z]

    def __call__(self, z):
        # sin is odd, so its Gaussian mean is exactly 0. The centered
        # activation equals the raw one, and no numerical centering is needed.
        # (Previously this returned np.sign(z), which is not the sine.)
        return self.val(z)

    def val(self, z):
        return np.sin(z)

    def grad(self, z):
        return np.cos(z)


class TanhActivationFunction(ActivationFunction):
    """Hyperbolic-tangent activation; its Gaussian statistics are integrated numerically."""

    def val(self, z):
        return np.tanh(z)

    def grad(self, z):
        # d/dz tanh(z) = 1 - tanh(z)^2
        t = np.tanh(z)
        return 1 - t ** 2


class SigmoidActivationFunction(ActivationFunction):
    """Logistic sigmoid activation ``1 / (1 + exp(-z))``."""

    def val(self, z):
        return np.reciprocal(1 + np.exp(-z))

    def grad(self, z):
        """Derivative ``s * (1 - s)`` with ``s = val(z)``, the raw sigmoid."""
        # Must use the raw ``val``: ``self(z)`` is the centered activation
        # (val minus its Gaussian mean), which yields a wrong derivative.
        s = self.val(z)
        return s * (1 - s)
