import math

import torch
import numpy as np


def he_interpolate(x, scale_factor=2):
    """Upsample a 4-D tensor by repeating every element along its last two axes.

    Each element of the last two dimensions is replicated ``scale_factor``
    times in both directions (nearest-neighbour style), using only
    broadcasting (``expand``) and ``reshape``.

    Args:
        x: tensor with exactly four dimensions; the code unpacks them as
            ``(batch, channels, height, width)``.
        scale_factor (int): replication factor applied to both trailing axes.

    Returns:
        Tensor of shape
        ``(batch, channels, height * scale_factor, width * scale_factor)``.
    """
    n, c, h, w = x.shape
    # Insert a singleton axis after height and after width ...
    with_gaps = x[:, :, :, None, :, None]
    # ... broadcast those singleton axes to the replication factor ...
    tiled = with_gaps.expand(n, c, h, scale_factor, w, scale_factor)
    # ... and fold each (axis, copy) pair back into a single, longer axis.
    return tiled.reshape(n, c, h * scale_factor, w * scale_factor)


def he_exp(x, iters=8):
    r"""Approximate the exponential function with the limit definition.

    .. math::

        exp(x) = \lim_{n \rightarrow \infty} (1 + x / n) ^ n

    The limit is truncated at ``n = 2 ** iters``: ``1 + x / n`` is computed
    once and then squared ``iters`` times, so only one division, one
    addition and ``iters`` squarings are needed.

    Args:
        x: input tensor (must provide ``div`` and ``square``).
        iters (int): number of squarings; the approximation uses
            ``n = 2 ** iters``.

    Returns:
        Tensor holding ``(1 + x / 2 ** iters) ** (2 ** iters)``.
    """  # noqa: W605
    n_squarings = iters
    approx = x.div(2**n_squarings)
    approx = 1 + approx
    # Squaring d times raises the base to the power 2 ** d.
    for _step in range(n_squarings):
        approx = approx.square()
    return approx


def he_reciprocal(self, input_in_01=False, all_pos=False, initial=None, nr_iters=10):
    r"""Approximate the element-wise reciprocal ``1 / self`` with Newton-Raphson.

    The `Newton-Raphson`_ iteration used is
    :math:`y_{i+1} = 2 y_i - self \cdot y_i^2`. When no starting value is
    given, :math:`3 \exp(1 - 2 \cdot self) + 0.003` is used as the initial
    guess.

    Args:
        self: input tensor.
        input_in_01 (bool): indicates that the input lies in the range
            [0, 1]. The input is multiplied by 64, its reciprocal is computed
            as for an all-positive input, and the result is multiplied by 64
            again (``64 / (64 * x) == 1 / x``).
        all_pos (bool): indicates that every element is known to be positive,
            which skips the sign computation. When ``False`` the reciprocal
            is computed on ``sign(self) * self`` and the sign is re-applied
            to the result.
        initial: optional starting value for the iteration, i.e. a rough
            estimate of ``1 / self``. It is never modified in place. In the
            ``input_in_01`` and sign-handling branches it is rescaled /
            re-signed to match the transformed input.
        nr_iters (int): number of Newton-Raphson iterations to run. It is
            forwarded to the recursive calls made by the ``input_in_01`` and
            sign-handling branches.

    Returns:
        Tensor approximating ``1 / self``.

    .. _Newton-Raphson:
        https://en.wikipedia.org/wiki/Newton%27s_method
    """
    if input_in_01:
        # A guess for 1 / x becomes a guess for 1 / (64 x) after dividing by 64.
        scaled_initial = None if initial is None else initial / 64
        rec = he_reciprocal(
            self.mul(64), all_pos=True, initial=scaled_initial, nr_iters=nr_iters
        )
        return rec.mul(64)

    if not all_pos:
        sgn = self.sign()
        pos = sgn * self
        # A guess for 1 / x becomes a guess for 1 / |x| after applying the sign.
        pos_initial = None if initial is None else sgn * initial
        return sgn * he_reciprocal(
            pos, all_pos=True, initial=pos_initial, nr_iters=nr_iters
        )

    if initial is None:
        # Initialization to a decent estimate (found by qualitative inspection):
        #                1/x = 3exp(1 - 2x) + 0.003
        result = 3 * he_exp(1 - 2 * self) + 0.003
    else:
        result = initial
    for _ in range(nr_iters):
        if hasattr(result, "square"):
            # Out-of-place update: `result` may still alias the caller's
            # `initial`, which must not be overwritten.
            result = result + (result - result.square().mul_(self))
        else:
            result = 2 * result - result * result * self
    return result



def he_inv_sqrt(self, initial=None, iters=10, exp_iters=8):
    r"""Approximate the element-wise inverse square root ``self ** -0.5``.

    Uses the `Newton-Raphson`_ iteration
    :math:`y_{i+1} = y_i (3 - self \cdot y_i^2) / 2`.

    Args:
        self: input tensor.
        initial: optional starting value for the iteration. When ``None``,
            :math:`2.2 \exp(-(self / 2 + 0.2)) + 0.2 - self / 1024` is used.
            The supplied tensor is never modified in place.
        iters (int): number of Newton-Raphson iterations to run.
        exp_iters (int): ``iters`` argument forwarded to ``he_exp`` when the
            default starting value is computed.

    Returns:
        Tensor approximating ``1 / sqrt(self)``.

    .. _Newton-Raphson:
        https://en.wikipedia.org/wiki/Fast_inverse_square_root#Newton's_method
    """
    # Initialize using decent approximation
    if initial is None:
        y = he_exp(self.div(2).add(0.2).neg(), iters=exp_iters).mul(2.2).add(0.2)
        y -= self.div(1024)
    else:
        y = initial

    # Newton Raphson iterations for inverse square root.
    # Out-of-place ops: `y` may alias the caller's `initial`, which must not
    # be overwritten.
    for _ in range(iters):
        y = y.mul(3 - self * y.square()).div(2)
    return y


def he_softmax(x, dim=-1):
    r"""Approximate softmax along ``dim`` using ``he_exp`` and ``he_reciprocal``.

    Steps:

    1. Subtract the maximum along ``dim`` so every exponent is non-positive.
    2. Exponentiate with ``he_exp`` (10 squarings) and divide by
       ``sqrt(n * 2**6)``, where ``n`` is the size of ``dim``. The factor
       cancels in the final quotient; presumably it keeps the sum in a range
       where ``he_reciprocal`` converges well.
    3. Multiply by the approximate reciprocal of the sum along ``dim``.
    4. Spread the residual ``1 - sum`` evenly over the ``n`` entries so the
       result sums to one along ``dim``.

    Args:
        x: input tensor.
        dim (int): dimension along which softmax is computed.

    Returns:
        Tensor of the same shape as ``x`` whose entries sum to one along
        ``dim`` (up to floating-point error).
    """
    size = x.shape[dim]

    x_max = x.max(dim=dim, keepdim=True)[0]
    x = x - x_max
    x_exp = he_exp(x, iters=10).div(np.sqrt(size * (2**6)))
    x_sum = he_reciprocal(x_exp.sum(dim=dim, keepdim=True))
    x = x_exp * x_sum

    # The residual is shared among the entries along `dim`, so it must be
    # divided by the size of that dimension (not of the last one).
    remainder = (1 - x.sum(dim=dim, keepdim=True)).div(size)
    x = x + remainder

    return x


def he_sigmoid(x, scale_bit=-4, exp_iters=10):
    r"""Approximate the logistic sigmoid using only non-positive exponents.

    The input is split into its negative part :math:`x^- = \min(x, 0)` and
    positive part :math:`x^+ = \max(x, 0)` and the function is evaluated as

    .. math::
        \sigma(x) = \frac{e^{x^-}}{e^{-x^+} + e^{x^-}}

    so that ``he_exp`` is only ever called on non-positive arguments. The
    denominator is multiplied by ``2 ** scale_bit`` before ``he_reciprocal``
    is applied and the reciprocal is multiplied by the same factor
    afterwards, which leaves the quotient unchanged. The result is finally
    clamped to ``[0, 1]`` using sign-based masks.

    Args:
        x: input tensor.
        scale_bit (int): exponent of the power-of-two factor applied around
            the reciprocal.
        exp_iters (int): ``iters`` argument forwarded to ``he_exp``.

    Returns:
        Tensor approximating ``sigmoid(x)``, clamped to ``[0, 1]``.
    """  # noqa: W605
    # 1 where x < 0 and 0 where x > 0 (0.5 at x == 0, where it multiplies 0).
    is_negative = (1 - x.sign()) / 2
    negative_part = x.mul(is_negative)
    positive_part = x - negative_part

    exp_of_pos = he_exp(positive_part.neg(), iters=exp_iters)
    exp_of_neg = he_exp(negative_part, iters=exp_iters)

    factor = 2**scale_bit
    inv_denominator = he_reciprocal((exp_of_pos + exp_of_neg).mul(factor))
    out = exp_of_neg * inv_denominator.mul(factor)

    # Clamp from above: entries greater than 1 are replaced by 1.
    above_one = (1 - (1 - out).sign()) / 2
    out = out * (1 - above_one) + above_one

    # Clamp from below: entries less than 0 are replaced by 0.
    below_zero = (1 - out.sign()) / 2
    out = out * (1 - below_zero)

    return out
    

def he_tanh(x, scale_bit=-6, exp_iters=10):
    r"""Approximate the hyperbolic tangent through the sigmoid identity

    .. math::
        tanh(x) = 2\sigma(2x) - 1

    Args:
        x: input tensor.
        scale_bit (int): forwarded to ``he_sigmoid`` as its ``scale_bit``.
        exp_iters (int): forwarded to ``he_sigmoid`` as its ``exp_iters``.

    Returns:
        Tensor approximating ``tanh(x)``.
    """
    # Forward the caller's scale_bit instead of a hard-coded constant.
    doubled = x.mul(2)
    sig = he_sigmoid(doubled, scale_bit=scale_bit, exp_iters=exp_iters)
    return sig.mul(2).sub(1)


def he_erf(x, scale_bit=-6, exp_iters=10):
    r"""Approximate the error function with a tanh-based formula.

    .. math::
        erf(x) \approx tanh\left(\frac{2}{\sqrt{\pi}}
            \left(x + \frac{11}{123} x^3\right)\right)

    Args:
        x: input tensor.
        scale_bit (int): forwarded to ``he_tanh``.
        exp_iters (int): forwarded to ``he_tanh``.

    Returns:
        Tensor approximating ``erf(x)``.
    """
    cubic_term = x.pow(3).mul(11/123)
    tanh_argument = (x + cubic_term).mul(2.0 / math.sqrt(math.pi))
    return he_tanh(tanh_argument, scale_bit=scale_bit, exp_iters=exp_iters)
    