import jax.numpy as np
import jax
import math
import functools

# Per-depth layout table, keyed by network depth. Every entry describes four
# groups; the entries differ only in how many blocks each group holds.
# `channels_per_group` is stored here but not read by the code in this file.
CONFIGS = {
    depth: {
        "blocks_per_group": blocks,
        "bottleneck": True,
        "channels_per_group": (256, 512, 1024, 2048),
        "use_projection": (True, True, True, True),
    }
    for depth, blocks in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
    )
}

def lrelu_kernel(c, alpha=0.0):
  """Evaluate the closed-form leaky-ReLU map at `c` for negative slope `alpha`.

  Computes
    ((1 - alpha)^2 * (sqrt(1 - c^2) + (pi - arccos(c)) * c) / pi
     + 2 * alpha * c) / (1 + alpha^2).
  For `alpha = 1` this reduces to `c`, and for any `alpha` it returns 1 at
  `c = 1`.
  NOTE(review): presumably this is the correlation map of a normalized leaky
  ReLU, with `c` an input correlation -- confirm against the derivation.

  Args:
    c: Scalar or array of values, nominally in [-1, 1].
    alpha: Negative slope of the leaky ReLU.

  Returns:
    A JAX array with the map evaluated elementwise at `c`.
  """
  # Repeated composition can push `c` marginally outside [-1, 1] through
  # floating-point round-off; sqrt/arccos would then return NaN. Clamp first.
  c = np.clip(c, -1.0, 1.0)
  relu_part = (np.sqrt(1 - c ** 2) + (math.pi - np.arccos(c)) * c) / math.pi
  return ((1 - alpha) ** 2 * relu_part + 2 * alpha * c) / (1 + alpha ** 2)

def global_cmap_fn(local_cmap_fn, depth, c_init=0.0, shortcut_weight=0.0):
  """Compose a per-layer map through the block layout of `CONFIGS[depth]`.

  Starting from `c_init`, every block of each of the four groups updates the
  running value as

    shortcut_weight**2 * shortcut + (1 - shortcut_weight**2) * main

  where `main` is `local_cmap_fn` applied three times (bottleneck layout) or
  twice (otherwise) to the running value, and `shortcut` is the running value
  itself, passed through `local_cmap_fn` once when the block is the first of
  its group and that group uses a projection. One final `local_cmap_fn` is
  applied after the last block.

  Args:
    local_cmap_fn: Callable mapping a value to a value; one layer's map.
    depth: Key into `CONFIGS` selecting the layout.
    c_init: Value fed into the first block.
    shortcut_weight: Weight of the shortcut branch; it enters squared.

  Returns:
    The value obtained after all blocks and the final `local_cmap_fn`.
  """
  config = CONFIGS[depth]
  group_sizes = config["blocks_per_group"]
  main_layers = 3 if config["bottleneck"] else 2
  projected = config["use_projection"]
  skip_mix = shortcut_weight**2

  value = c_init
  for group in range(4):
    for block in range(group_sizes[group]):
      # Main branch: the local map stacked `main_layers` times.
      main_value = value
      for _ in range(main_layers):
        main_value = local_cmap_fn(main_value)

      # Shortcut branch: identity, except on a projected first block.
      if block == 0 and projected[group]:
        skip_value = local_cmap_fn(value)
      else:
        skip_value = value

      value = skip_mix * skip_value + (1.0 - skip_mix) * main_value

  return local_cmap_fn(value)

def binary_search(fn, target, input_=0.0, min_=-1.0, max_=1.0, tol=1e-6):
  """Search for an input at which a decreasing function reaches `target`.

  The update rule assumes `fn` is decreasing on the bracket: when
  `fn(input_)` is below `target` the next probe moves toward `min_`, and when
  it is above `target` the next probe moves toward `max_`. A bracket end may
  be infinite, in which case the probe is pushed outward (doubling once it is
  on the correct side of zero) until a finite bound has been established.

  Args:
    fn: Callable taking a scalar and returning a scalar.
    target: Value `fn` should attain.
    input_: First point to probe.
    min_: Lower end of the bracket; may be `-inf`.
    max_: Upper end of the bracket; may be `inf`.
    tol: Absolute tolerance on `|fn(x) - target|`.

  Returns:
    The first probed input `x` with `|fn(x) - target| < tol`.

  Raises:
    ValueError: If `fn(input_)` cannot be ordered against `target`
      (for example because it is NaN).
    RuntimeError: If the probe can no longer move, i.e. the bracket has
      collapsed to floating-point resolution without meeting `tol`.
  """
  # Iterative rather than recursive, so a slow or failed search cannot
  # exhaust the interpreter's recursion limit.
  while True:
    value = fn(input_)
    if np.abs(value - target) < tol:
      return input_

    if value < target:
      # Value too low: for a decreasing fn the input must shrink.
      if np.isinf(min_):
        new_input = input_ * 2 if input_ < 0 else input_ - max(1.0, input_)
      else:
        new_input = 0.5 * (input_ + min_)
      max_ = input_
    elif value > target:
      # Value too high: the input must grow. Plain doubling would stall at 0
      # and move the wrong way for negative inputs, so step upward instead
      # until the probe is positive.
      if np.isinf(max_):
        new_input = input_ * 2 if input_ > 0 else input_ + max(1.0, -input_)
      else:
        new_input = 0.5 * (input_ + max_)
      min_ = input_
    else:
      raise ValueError(
          f"fn({input_}) = {value} cannot be ordered against target {target}")

    # Written so that a NaN probe also counts as "did not move".
    if not (new_input < input_ or new_input > input_):
      raise RuntimeError(
          f"binary_search stalled at input {input_} (fn = {value}) without "
          f"reaching target {target} within tol {tol}")
    input_ = new_input

def get_transformed_lrelu(depth, shortcut_weight, target_value):
  """Build the transformed leaky-ReLU activation; the module's main entry point.

  The negative slope is found with `binary_search` (using its default start
  point, bracket and tolerance) so that `global_cmap_fn`, driven by
  `lrelu_kernel` at that slope and evaluated at its default `c_init`, matches
  `target_value`.

  Args:
    depth: Key into `CONFIGS` selecting the network layout.
    shortcut_weight: Shortcut weight forwarded to `global_cmap_fn`.
    target_value: Value the composed map should reach.

  Returns:
    A callable `x -> sqrt(2 / (1 + s**2)) * leaky_relu(x, negative_slope=s)`,
    where `s` is the slope found by the search.
  """
  def network_cmap_at_slope(alpha):
    slope_kernel = functools.partial(lrelu_kernel, alpha=alpha)
    return global_cmap_fn(slope_kernel, depth, shortcut_weight=shortcut_weight)

  negative_slope = binary_search(network_cmap_at_slope, target_value)
  # Depends only on the slope, so it is computed once rather than per call.
  output_scale = math.sqrt(2.0 / (1 + negative_slope**2))

  def transformed_lrelu(x):
    return output_scale * jax.nn.leaky_relu(x, negative_slope=negative_slope)

  return transformed_lrelu