import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from scipy.stats import norm
from scipy.optimize import fsolve
from sympy import *
import fselu


class SeLUv2(nn.Module):
    """Scaled ELU whose scale ``lamb_`` and ELU ``alpha_`` are solved for numerically.

    The two coefficients are obtained once, at construction time, by handing a
    two-equation system to ``scipy.optimize.fsolve``; ``forward`` then applies
    ``lamb_ * elu(input, alpha_)``.
    """

    def __init__(self, gamma_, fixpoint, epsilon_=0.05):
        """
        Implementation of SeLU in TPAMI paper.

        Solves for ``lamb_`` and ``alpha_`` starting from the guess ``[1., 1.]``
        and prints both values.

        :param gamma_: the std of weights
        :param fixpoint: the fixpoint of the pre-activations' second-order moment
        :param epsilon_: the epsilon; the first equation is driven to ``1 + epsilon_``
        """
        super(SeLUv2, self).__init__()
        self.g_ = gamma_
        self.fp = fixpoint
        self.e = epsilon_

        # Variance of the zero-mean Gaussian that both equations integrate against.
        var = self.fp * self.g_
        gauss = norm(0, np.sqrt(var))

        def residuals(unknowns):
            """Return the two residuals that fsolve drives to zero."""
            lam, alp = unknowns[0], unknowns[1]
            scaled_sq = np.square(lam * alp)

            # First equation: target value is 1 + epsilon.
            first = (scaled_sq * np.exp(2. * var) * gauss.cdf(-2. * var)
                     + np.square(lam) / 2.) * self.g_

            # Second equation: the expression divided by the fixpoint must equal 1.
            second = (scaled_sq * (np.exp(2. * var) * gauss.cdf(-2. * var)
                                   - 2. * np.exp(var / 2.) * gauss.cdf(-var))
                      + 0.5 * scaled_sq + 0.5 * lam * lam * var) / self.fp

            return [first - (1 + self.e), second - 1.]

        solution = fsolve(residuals, [1., 1.])
        self.lamb_, self.alpha_ = solution
        print('lambda: %.6f' % self.lamb_)
        print('alpha: %.6f' % self.alpha_)

    def forward(self, input):
        """Apply ELU with the solved ``alpha_`` and scale the result by ``lamb_``."""
        activated = F.elu(input, self.alpha_)
        return self.lamb_ * activated


class FuselSELUfn(torch.autograd.Function):
    """Autograd wrapper around the ``fselu`` extension's ``fuse_lselu_f`` / ``fuse_lselu_b``.

    NOTE(review): the kernels are defined outside this file; the exact math
    they implement cannot be confirmed from here.
    """

    @staticmethod
    def forward(ctx, x, alpha, beta, lambda_):
        """Call the forward kernel and stash everything ``backward`` reads."""
        out = fselu.fuse_lselu_f(x, alpha, beta, lambda_)
        # alpha/beta are kept as plain ctx attributes; x and lambda_ go through
        # save_for_backward.
        ctx.alpha, ctx.beta = alpha, beta
        ctx.save_for_backward(x, lambda_)
        return out

    @staticmethod
    def backward(ctx, grad_y):
        """Call the backward kernel; return one gradient slot per forward input."""
        saved_x, saved_lambda = ctx.saved_tensors
        grads = fselu.fuse_lselu_b(grad_y, saved_x, ctx.alpha, ctx.beta, saved_lambda, True)
        grad_x, grad_lambda = grads
        # Slots line up with (x, alpha, beta, lambda_); alpha and beta get None.
        return grad_x, None, None, grad_lambda


fuse_lselu = FuselSELUfn.apply


class lSeLU(nn.Module):
    """Activation module that calls ``fuse_lselu`` with a learnable scale.

    At construction the module builds symbolic Gaussian integrals with sympy,
    turns them into numeric callables, and grid-searches ``lambda`` (solving
    for ``alpha``/``beta`` at each grid point) to pick the admissible
    coefficients with the smallest slope.  The selected ``lambda`` becomes the
    initial value of the learnable ``multiplier`` parameter.
    """

    def __init__(self, q_in, epsilon):
        """
        :param q_in: value substituted for the Gaussian variance ``q`` when the
            integrals are evaluated numerically
        :param epsilon: the first equation is driven to ``1 + epsilon``; it is
            also the residual threshold below which a solve is accepted
        """
        super(lSeLU, self).__init__()
        self.fixpoint = 1.
        self.q = q_in
        self.epsilon = epsilon

        l, a, b, x = symbols('l a b x', real=True)
        q = symbols('q', positive=True)
        # Density of a zero-mean Gaussian with variance q.
        px = exp(-x*x/2/q) / sqrt(2*pi*q)

        # On x < 0 the integrand of phiv2 is the square of the x-derivative of
        # the expression squared in q_hatv2; on x > 0 they are l**2 and (l*x)**2.
        phiv2 = integrate(l**2 * px, (x, 0, +oo)) + integrate((l*a*exp(x) + l*b)**2 * px, (x, -oo, 0))
        q_hatv2 = integrate(l**2 * x**2 * px, (x, 0, +oo)) + integrate((l*a*exp(x) + l *b * x - l*a)**2 * px, (x, -oo, 0))
        slopev2 = diff(phiv2, q)

        # Placeholders; they survive only if optimize() accepts no candidate.
        self.alpha_ = 1.
        self.beta = 1.
        self.lambda_ = 1.

        # functions to calculate phi, q, and slope
        self.phi_fn = lambdify([l, a, b, q], phiv2, "numpy")
        self.q_fn = lambdify([l, a, b, q], q_hatv2, "numpy")
        self.slope_fn = lambdify([l, a, b, q], slopev2, "numpy")

        self.multiplier = nn.Parameter(torch.ones(1))

        # get the optimal value
        self.optimize(lmax=1.5, step=0.001)

        # Initialise the learnable scale with the selected lambda without
        # recording the write in autograd.
        with torch.no_grad():
            self.multiplier.fill_(self.lambda_)

    def get_coefficient(self, i, l):
        """Return the two residuals solved by ``fsolve`` for a fixed ``l``.

        :param i: pair ``(alpha, beta)`` currently proposed by the solver
        :param l: the lambda value held fixed during this solve
        :return: ``[q * phi - (1 + epsilon), q_hat / fixpoint - 1]``
        """
        a_, b_ = i[0], i[1]
        return [
            self.q * self.phi_fn(l, a_, b_, self.q) - (1 + self.epsilon),
            self.q_fn(l, a_, b_, self.q) / self.fixpoint - 1.
        ]

    def optimize(self, lmax, step):
        """Grid-search lambda in ``[1, lmax)`` and keep the lowest-slope solution.

        For every lambda on the grid, ``alpha`` and ``beta`` are solved from
        ``get_coefficient``.  A candidate is accepted when its slope is below
        the best so far (starting at 1), ``beta`` is positive and the largest
        residual is below ``epsilon``.  If nothing is accepted a
        ``RuntimeWarning`` is issued and the existing coefficients are kept.

        :param lmax: exclusive upper bound of the lambda grid
        :param step: spacing of the lambda grid
        """
        import warnings  # function-scope import keeps this class self-contained

        best_slope = 1
        found = False

        for l_ in np.arange(1., lmax, step):
            (a_, b_), info, _, _ = fsolve(
                self.get_coefficient, [1., 1.], args=(l_,), full_output=True)
            slope_ = self.slope_fn(l_, a_, b_, self.q)
            error = np.max(np.abs(info['fvec']))
            if (slope_ < best_slope) and (b_ > 0) and (error < self.epsilon):
                # Store plain Python floats rather than NumPy scalars.
                self.alpha_ = float(a_)
                self.beta = float(b_)
                self.lambda_ = float(l_)
                best_slope = slope_
                found = True

        if not found:
            warnings.warn(
                "lSeLU.optimize found no admissible (lambda, alpha, beta) for "
                "q=%r, epsilon=%r; keeping the current coefficients"
                % (self.q, self.epsilon),
                RuntimeWarning,
            )

        print('lambda: %.9f, alpha: %.9f, beta: %.9f' % (self.lambda_, self.alpha_, self.beta))

    def forward(self, input_):
        """Apply the fused op, passing the learnable ``multiplier`` as its lambda argument."""
        return fuse_lselu(input_, self.alpha_, self.beta, self.multiplier)



class FusesSELUfn(torch.autograd.Function):
    """Autograd wrapper around the ``fselu`` extension's ``fuse_sselu_f`` / ``fuse_sselu_b``.

    NOTE(review): the kernels are defined outside this file; the exact math
    they implement cannot be confirmed from here.
    """

    @staticmethod
    def forward(ctx, x, alpha, beta, lambda_):
        """Call the forward kernel and stash everything ``backward`` reads."""
        out = fselu.fuse_sselu_f(x, alpha, beta, lambda_)
        # alpha/beta are kept as plain ctx attributes; x and lambda_ go through
        # save_for_backward.
        ctx.alpha, ctx.beta = alpha, beta
        ctx.save_for_backward(x, lambda_)
        return out

    @staticmethod
    def backward(ctx, grad_y):
        """Call the backward kernel; return one gradient slot per forward input."""
        saved_x, saved_lambda = ctx.saved_tensors
        grads = fselu.fuse_sselu_b(grad_y, saved_x, ctx.alpha, ctx.beta, saved_lambda, True)
        grad_x, grad_lambda = grads
        # Slots line up with (x, alpha, beta, lambda_); alpha and beta get None.
        return grad_x, None, None, grad_lambda


fuse_sselu = FusesSELUfn.apply



class sSeLU(nn.Module):
    """Activation module that calls ``fuse_sselu`` with a learnable scale.

    At construction the module builds symbolic Gaussian integrals with sympy,
    turns them into numeric callables, and grid-searches ``lambda`` (solving
    for ``alpha``/``beta`` at each grid point) to pick the admissible
    coefficients with the smallest slope.  The selected ``lambda`` becomes the
    initial value of the learnable ``multiplier`` parameter.
    """

    def __init__(self, q_in, epsilon):
        """
        :param q_in: value substituted for the Gaussian variance ``q`` when the
            integrals are evaluated numerically
        :param epsilon: the first equation is driven to ``1 + epsilon``; it is
            also the residual threshold below which a solve is accepted
        """
        super(sSeLU, self).__init__()
        self.fixpoint = 1.
        self.q = q_in
        self.epsilon = epsilon

        l, a, b, x = symbols('l a b x', real=True)
        q = symbols('q', positive=True)
        # Density of a zero-mean Gaussian with variance q.
        px = exp(-x*x/2/q) / sqrt(2*pi*q)

        # On x < 0 the integrand of phiv2 is the square of the x-derivative of
        # the expression squared in q_hatv2; on x > 0 they are l**2 and (l*x)**2.
        phiv2 = integrate(l**2 * px, (x, 0, +oo)) + integrate(l**2*a**2*b**2*exp(2*b*x) * px, (x, -oo, 0))
        q_hatv2 = integrate(l**2 * x**2 * px, (x, 0, +oo)) + integrate((l*a*exp(b*x) - l*a)**2 * px, (x, -oo, 0))
        slopev2 = diff(phiv2, q)

        # Placeholders; they survive only if optimize() accepts no candidate.
        self.alpha_ = 1.
        self.beta = 1.
        self.lambda_ = 1.

        # functions to calculate phi, q, and slope
        self.phi_fn = lambdify([l, a, b, q], phiv2, "numpy")
        self.q_fn = lambdify([l, a, b, q], q_hatv2, "numpy")
        self.slope_fn = lambdify([l, a, b, q], slopev2, "numpy")

        self.multiplier = nn.Parameter(torch.ones(1))

        # get the optimal value
        self.optimize(lmax=1.5, step=0.001)

        # Initialise the learnable scale with the selected lambda without
        # recording the write in autograd.
        with torch.no_grad():
            self.multiplier.fill_(self.lambda_)

    def get_coefficient(self, i, l):
        """Return the two residuals solved by ``fsolve`` for a fixed ``l``.

        :param i: pair ``(alpha, beta)`` currently proposed by the solver
        :param l: the lambda value held fixed during this solve
        :return: ``[q * phi - (1 + epsilon), q_hat / fixpoint - 1]``
        """
        a_, b_ = i[0], i[1]
        return [
            self.q * self.phi_fn(l, a_, b_, self.q) - (1 + self.epsilon),
            self.q_fn(l, a_, b_, self.q) / self.fixpoint - 1.
        ]

    def optimize(self, lmax, step):
        """Grid-search lambda in ``[1, lmax)`` and keep the lowest-slope solution.

        For every lambda on the grid, ``alpha`` and ``beta`` are solved from
        ``get_coefficient``.  A candidate is accepted when its slope is below
        the best so far (starting at 1), ``beta`` is positive and the largest
        residual is below ``epsilon``.  If nothing is accepted a
        ``RuntimeWarning`` is issued and the existing coefficients are kept.

        :param lmax: exclusive upper bound of the lambda grid
        :param step: spacing of the lambda grid
        """
        import warnings  # function-scope import keeps this class self-contained

        best_slope = 1
        found = False

        for l_ in np.arange(1., lmax, step):
            (a_, b_), info, _, _ = fsolve(
                self.get_coefficient, [1., 1.], args=(l_,), full_output=True)
            slope_ = self.slope_fn(l_, a_, b_, self.q)
            error = np.max(np.abs(info['fvec']))
            if (slope_ < best_slope) and (b_ > 0) and (error < self.epsilon):
                # Store plain Python floats rather than NumPy scalars.
                self.alpha_ = float(a_)
                self.beta = float(b_)
                self.lambda_ = float(l_)
                best_slope = slope_
                found = True

        if not found:
            warnings.warn(
                "sSeLU.optimize found no admissible (lambda, alpha, beta) for "
                "q=%r, epsilon=%r; keeping the current coefficients"
                % (self.q, self.epsilon),
                RuntimeWarning,
            )

        print('lambda: %.9f, alpha: %.9f, beta: %.9f' % (self.lambda_, self.alpha_, self.beta))

    def forward(self, input_):
        """Apply the fused op, passing the learnable ``multiplier`` as its lambda argument."""
        return fuse_sselu(input_, self.alpha_, self.beta, self.multiplier)
