import jax
import scipy as osp
from autograd import elementwise_grad as egrad
from autograd import numpy as np
from autograd import scipy as sp
import scipy.integrate as sp_int
from scipy.integrate.quadrature import _cached_roots_legendre
import scipy.optimize as sp_opt

# Scale (lambda) and alpha constants handed to the elu helpers by the "selu"
# entry of NONLINEARITIES.
SELU_LAMBDA = 1.0507
SELU_ALPHA = 1.67326
# NOTE(review): STEP is not referenced in the visible part of this file --
# confirm it is still needed.
STEP = 0.0005
# Order of the Gauss-Legendre rule; default `n` of _estimate_gaussian_mean.
NUM_POINTS = 100000
# Logistic sigmoid taken from autograd's scipy wrapper, so that `egrad` can
# differentiate through it.
SIGMOID = sp.special.expit
# Gauss-Legendre sample points and weights for NUM_POINTS, computed once at
# import time and reused by _estimate_gaussian_mean.
roots = osp.special.roots_legendre(NUM_POINTS)

def swish(x):
  """Swish activation: the input multiplied by its own sigmoid gate."""
  gate = SIGMOID(x)
  return x * gate


def swish_der(x):
  """First derivative of swish: sig(x) * (1 + x * (1 - sig(x)))."""
  gate = SIGMOID(x)
  return gate * (1. + x * (1 - gate))


def swish_der2(x):
  """Second derivative of swish in closed form.

  Equals sig(x) * (1 - sig(x)) * (2 + x * (1 - 2 * sig(x))).
  """
  gate = SIGMOID(x)
  return gate * (1. - gate) * (2. + x * (1. - 2. * gate))


def safe_softplus(x):
  """Numerically-stable softplus, i.e. log(1 + exp(x)).

  Evaluated as max(x, 0) + log1p(exp(-|x|)): the exponent is never positive,
  so `exp` cannot overflow, and `log1p` keeps the tiny tail for very negative
  `x` instead of rounding `1 + exp(-|x|)` to exactly 1.

  Args:
    x: scalar or array input.

  Returns:
    softplus(x), with the same shape as `x`.
  """
  return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)


# Registry of the supported base activations. Every entry provides
#   "fn":   the activation itself,
#   "der":  its first derivative,
#   "curv": its second derivative,
# written in closed form where that is convenient and otherwise obtained by
# differentiating with `egrad`.
# NOTE(review): `elu`, `elu_der` and `elu_curv` are not defined in the visible
# part of this file -- confirm they exist elsewhere in the module, otherwise
# the "selu" entry raises NameError the first time it is called.
NONLINEARITIES = dict(
    tanh=dict(
        fn=np.tanh,
        der=lambda x: 1. - (np.tanh(x)**2),
        curv=egrad(egrad(np.tanh)),
    ),
    selu=dict(
        fn=lambda x: elu(x, SELU_ALPHA, SELU_LAMBDA),
        der=lambda x: elu_der(x, SELU_ALPHA, SELU_LAMBDA),
        curv=lambda x: elu_curv(x, SELU_ALPHA, SELU_LAMBDA),
    ),
    softplus=dict(
        fn=safe_softplus,
        der=SIGMOID,
        curv=egrad(SIGMOID),
    ),
    swish=dict(
        fn=swish,
        der=swish_der,
        curv=egrad(swish_der),
    ),
)

class ParameterizedNonlinearity(object):
  """An affinely transformed activation plus statistics derived from it.

  The transformed function is

      x -> output_scale * (phi(input_scale * x + input_shift) + output_shift)

  where `phi` (and its first/second derivatives) are looked up by name in
  `NONLINEARITIES`. On construction `output_scale` is multiplied by
  1 / sqrt(q), with q the Gaussian mean of the squared transformed function,
  and the statistics `curv1`, `chi0`, `chi1`, `q_output` and `qslope1` are
  then computed for the rescaled function.
  """

  def __init__(self, nonlin_str, input_scale, input_shift,
               output_shift, output_scale=1.0,):
    self.nonlin_str = nonlin_str

    entry = NONLINEARITIES[nonlin_str]
    self.phi = entry["fn"]
    self.phi_der = entry["der"]
    self.phi_curv = entry["curv"]

    self.params = {
        "input_scale": input_scale,
        "input_shift": input_shift,
        "output_scale": output_scale,
        "output_shift": output_shift,
    }

    # Rescale the output by 1 / sqrt(q), where q is the Gaussian mean of the
    # squared function before rescaling. The statistics below pass
    # q_output=1.0 because of this normalisation.
    second_moment = _calc_output_q(self)
    self.params["output_scale"] = (
        np.sqrt(1. / second_moment) * self.params["output_scale"])

    # Second derivative of the C map at c=1.
    self.curv1 = _calc_curv1(self, q_output=1.0)

    # Slope of the C map at c=0.
    self.chi0 = _calc_chi(self, target="chi0", q_output=1.0)

    # Slope of the C map at c=1.
    self.chi1 = _calc_chi(self, target="chi1", q_output=1.0)

    # Gaussian mean of the squared function after rescaling.
    self.q_output = _calc_output_q(self)

    # Slope of the Q map, computed by _calc_qslope1.
    self.qslope1 = _calc_qslope1(self)

  def _pre_activation(self, x):
    """Apply the input scale and shift to `x`."""
    return self.params["input_scale"] * x + self.params["input_shift"]

  def fn(self, x):
    """Evaluate the transformed activation at `x`."""
    p = self.params
    inner = self.phi(self._pre_activation(x))
    return p["output_scale"] * (inner + p["output_shift"])

  def der(self, x):
    """First derivative of `fn` with respect to `x`."""
    p = self.params
    slope = self.phi_der(self._pre_activation(x))
    return p["output_scale"] * slope * p["input_scale"]

  def curv(self, x):
    """Second derivative of `fn` with respect to `x`."""
    p = self.params
    bend = self.phi_curv(self._pre_activation(x))
    return p["output_scale"] * bend * p["input_scale"] ** 2

  @property
  def chi_ratio(self):
    """Ratio chi1 / chi0."""
    return self.chi1 / self.chi0


def _estimate_gaussian_mean(fn, n=NUM_POINTS):
  """Estimate the mean of a function fn(x) where x ~ N(0,1).

  The expectation is approximated by integrating exp(-x**2 / 2) * fn(x) over
  [-10, 10] with an n-point Gauss-Legendre rule and dividing by sqrt(2 * pi).

  Args:
    fn: callable evaluated on the array of quadrature sample points that
      `fixed_quad` passes in.
    n: order of the Gauss-Legendre rule.

  Returns:
    The estimated Gaussian mean of `fn`.
  """
  # `roots` holds the nodes and weights for exactly NUM_POINTS points, so it
  # may only be placed in SciPy's cache under that order. Storing it under a
  # different `n` would hand the wrong rule to every later n-point call.
  if n == NUM_POINTS:
    _cached_roots_legendre.cache[n] = roots
  fn_weighted = lambda x: np.exp(-x**2 / 2) * fn(x)
  integral, _ = sp_int.fixed_quad(fn_weighted, -10., 10., n=n)

  return integral / np.sqrt(2 * np.pi)


def _calc_chi(nl, target="chi1", q_output=None):
  # Estimate result using numerical integration:

  if target == "chi0":
    d_int = _estimate_gaussian_mean(nl.der)**2
  elif target == "chi1":
    fn = lambda x: nl.der(x)**2
    d_int = _estimate_gaussian_mean(fn)

  if q_output is None:
    q_output = _calc_output_q(nl)

  chi = d_int / q_output

  return chi


def _calc_curv1(nl, q_output=None):
  """Return E[nl.curv(x)**2] / q_output for x ~ N(0, 1).

  When `q_output` is None it is computed with `_calc_output_q(nl)`.
  """
  def squared_curv(x):
    return nl.curv(x)**2

  mean_sq_curv = _estimate_gaussian_mean(squared_curv)
  denom = _calc_output_q(nl) if q_output is None else q_output
  return mean_sq_curv / denom


def _calc_qslope1(nl):
  """Return E[nl.fn(x) * nl.der(x) * x] for x ~ N(0, 1).

  This assumes qin = 1.
  """
  def integrand(x):
    return nl.fn(x) * nl.der(x) * x

  return _estimate_gaussian_mean(integrand)


def _calc_output_q(nl):
  """Return E[nl.fn(x)**2] for x ~ N(0, 1)."""
  def squared_output(x):
    return nl.fn(x)**2

  return _estimate_gaussian_mean(squared_output)


def _eval_nonlin_properties(nonlin_str, input_scale, input_shift, output_shift):
  """Build a ParameterizedNonlinearity with the default output scale."""
  return ParameterizedNonlinearity(
      nonlin_str, input_scale, input_shift, output_shift)

def _match_x_match_y_match_z(nonlin_str,
                             x_target,
                             y_target,
                             z_target,
                             x_string,
                             y_string,
                             z_string,
                             method="hybr",
                             options=None,
                             starting_point=(1.0, -0.5, 0.0)):
  """Find parameters for which three named properties hit their targets.

  The search runs over (input_scale, input_shift, output_shift) using
  `scipy.optimize.root`. The residual is the difference between the
  attributes named `x_string`, `y_string`, `z_string` of the resulting
  ParameterizedNonlinearity and the corresponding target values.

  Args:
    nonlin_str: key into NONLINEARITIES.
    x_target, y_target, z_target: desired property values.
    x_string, y_string, z_string: attribute names of the properties.
    method: solver name forwarded to `scipy.optimize.root`.
    options: solver options forwarded to `scipy.optimize.root`.
    starting_point: initial (input_scale, input_shift, output_shift).

  Returns:
    The ParameterizedNonlinearity evaluated at the solution.

  Raises:
    ValueError: if the solver does not report success.
  """
  attr_names = (x_string, y_string, z_string)
  goals = (x_target, y_target, z_target)

  def residuals(v):
    candidate = _eval_nonlin_properties(nonlin_str,
                                        input_scale=v[0],
                                        input_shift=v[1],
                                        output_shift=v[2])
    return np.asarray([
        getattr(candidate, name) - goal
        for name, goal in zip(attr_names, goals)
    ])

  sol = sp_opt.root(
      residuals,
      np.asarray(starting_point),
      method=method,
      jac=False,
      options=options)

  if not sol.success:
    raise ValueError("Root finding failed for given arguments: "
                     "nonlin_str={}, x_target={}, x_string={}, "
                     "y_target={}, y_string={}, z_target={}, z_string={}."
                     "".format(nonlin_str, x_target, x_string, y_target,
                               y_string, z_target, z_string))

  return _eval_nonlin_properties(
      nonlin_str,
      input_scale=sol.x[0],
      input_shift=sol.x[1],
      output_shift=sol.x[2])


def compute_nonlinearity_properties(nonlinearity, per_nl_curv1):
  """Search for parameters giving chi1 = 1, qslope1 = 1, curv1 = per_nl_curv1.

  Up to 50 root-finding attempts are made. The first 18 start from a fixed
  grid of (input_scale, input_shift, output_shift) points; the remaining ones
  start from random points with input_scale drawn from U(0, 2) and both
  shifts from U(-3, 3). A solution is accepted unless
  input_scale * output_scale exceeds 2.0, in which case the search continues.

  If no solution is accepted, the most recent rejected solution (if any) is
  still returned.

  Args:
    nonlinearity: key into NONLINEARITIES.
    per_nl_curv1: target value for `curv1`.

  Returns:
    The `params` dict of the selected ParameterizedNonlinearity.

  Raises:
    ValueError: if every attempt failed to converge.
  """
  scales = (1.0, 0.1)
  offsets = (0.0, 1.0, -1.0)
  # Grid order: input_shift varies fastest, then output_shift, then scale.
  fixed_starts = tuple(
      (scale, in_shift, out_shift)
      for scale in scales
      for out_shift in offsets
      for in_shift in offsets)

  found = None

  for attempt in range(50):
    if attempt < len(fixed_starts):
      start = fixed_starts[attempt]
    else:
      start = (np.random.uniform(low=0.0, high=2.0),
               np.random.uniform(low=-3.0, high=3.0),
               np.random.uniform(low=-3.0, high=3.0))

    try:
      found = _match_x_match_y_match_z(
          nonlinearity,
          1.0,
          1.0,
          per_nl_curv1,
          "chi1",
          "qslope1",
          "curv1",
          method="hybr",
          starting_point=start)
    except ValueError:
      print("Failed to find parameters for nonlinearity {} using "
            "starting point {}.".format(nonlinearity, start))
      continue

    print("Found parameters for nonlinearity {} using starting "
          "point {}.".format(nonlinearity, start))
    combined_scale = (
        found.params["input_scale"] * found.params["output_scale"])
    if combined_scale > 2.0:
      print("The solution is not good enough. Keep searching.")
      continue
    break

  if found is None:
    raise ValueError("Failed to find parameters for "
                     "nonlinearity {}.".format(nonlinearity))

  return found.params

def get_transformed_activation(act_fn, depth, shortcut_weight, target_value):
  """Build the JAX version of the transformed activation `act_fn`.

  The per-nonlinearity `curv1` target is derived from `target_value`, `depth`
  and the squared `shortcut_weight` with the formula below, and the matching
  parameters are found with `compute_nonlinearity_properties`.

  Args:
    act_fn: activation name; "tanh" or the name of a function in `jax.nn`
      that is also a key of NONLINEARITIES.
    depth: depth value entering the per-nonlinearity target.
    shortcut_weight: shortcut weight; it is squared before use.
    target_value: overall target that is divided into the per-nonlinearity one.

  Returns:
    A callable computing
    output_scale * (act(input_scale * x + input_shift) + output_shift).
  """
  # Imported locally because the module itself only does `import jax`.
  import jax.numpy as jnp

  w = shortcut_weight**2
  per_nl_curv1 = target_value / (5 + (1 - w) * (depth - 6))
  params = compute_nonlinearity_properties(act_fn, per_nl_curv1)

  # tanh lives in jax.numpy; the other activations are looked up in jax.nn.
  if act_fn == "tanh":
    activation = jnp.tanh
  else:
    activation = getattr(jax.nn, act_fn)

  def transformed_activation(x):
    shifted = params["input_scale"] * x + params["input_shift"]
    return params["output_scale"] * (activation(shifted) +
                                     params["output_shift"])

  return transformed_activation