import lowp._functions as f
import torch

# Rounding mode the BF16 helpers pass by default. The original note tags this
# value "RNE" (presumably round-to-nearest-even; confirm against lowp).
BF16_DEFAULT_ROUNDING_MODE = 3
# Default min/max noise arguments of the FP truncation helpers: the bounds of
# the signed 32-bit integer range.
DEFAULT_STOCH_ROUNDING_MIN = -(1 << 31)
DEFAULT_STOCH_ROUNDING_MAX = (1 << 31) - 1


def mantissa(input):
    """Run ``f.Mantissa`` on a contiguous version of ``input``.

    Args:
        input: Object exposing ``.contiguous()`` (a tensor in practice).

    Returns:
        Whatever ``f.Mantissa().apply`` returns for the contiguous input.
    """
    apply_mantissa = f.Mantissa().apply
    return apply_mantissa(input.contiguous())


def quantemu(input, mode, inplace=False):
    """Run ``f.QuantEmu`` on a contiguous version of ``input``.

    Args:
        input: Tensor to process; ``.contiguous()`` is called on it first.
        mode: Forwarded unchanged to ``f.QuantEmu``.
        inplace: Forwarded unchanged to ``f.QuantEmu``.

    Returns:
        The result of ``f.QuantEmu().apply``.

    NOTE(review): ``.contiguous()`` returns a copy for non-contiguous tensors,
    so with ``inplace=True`` the op would presumably modify that copy rather
    than the caller's tensor. Confirm this is intended.
    """
    emulate = f.QuantEmu().apply
    return emulate(input.contiguous(), mode, inplace)


def truncate_bf16(input, inplace=False, roundingMode=BF16_DEFAULT_ROUNDING_MODE):
    """Run ``f.TruncateBF16`` on a contiguous version of ``input``.

    Args:
        input: Tensor to process; ``.contiguous()`` is called on it first.
        inplace: Forwarded unchanged to ``f.TruncateBF16``.
        roundingMode: Forwarded unchanged; defaults to
            ``BF16_DEFAULT_ROUNDING_MODE``.

    Returns:
        The result of ``f.TruncateBF16().apply``.

    NOTE(review): ``.contiguous()`` returns a copy for non-contiguous tensors,
    so with ``inplace=True`` the op would presumably modify that copy rather
    than the caller's tensor. Confirm this is intended.
    """
    truncate = f.TruncateBF16().apply
    return truncate(input.contiguous(), inplace, roundingMode)


def truncate_grad_bf16(input, roundingMode=BF16_DEFAULT_ROUNDING_MODE, debugStr=None):
    """Run ``f.TruncateGradBF16`` on a contiguous version of ``input``.

    Judging by the name, this presumably applies BF16 truncation to the
    gradient rather than to the forward value; confirm against lowp.

    Args:
        input: Tensor to process; ``.contiguous()`` is called on it first.
        roundingMode: Forwarded unchanged; defaults to
            ``BF16_DEFAULT_ROUNDING_MODE``.
        debugStr: Forwarded unchanged to ``f.TruncateGradBF16``.

    Returns:
        The result of ``f.TruncateGradBF16().apply``.
    """
    truncate_grad = f.TruncateGradBF16().apply
    return truncate_grad(input.contiguous(), roundingMode, debugStr)


def truncate_fp(input, exp_width, man_width, exp_bias=None, inplace=False,
                roundingMode=0, min_noise=DEFAULT_STOCH_ROUNDING_MIN,
                max_noise=DEFAULT_STOCH_ROUNDING_MAX):
    """Run ``f.TruncateFP`` on ``input`` for a given exponent/mantissa layout.

    ``torch.half`` inputs are converted to float for the call and the result
    is converted back to half; other dtypes are passed through as they are.

    Args:
        input: Tensor to process; ``.contiguous()`` is called on it first.
        exp_width: Exponent width, forwarded to ``f.TruncateFP``.
        man_width: Mantissa width, forwarded to ``f.TruncateFP``.
        exp_bias: Exponent bias. When ``None`` it defaults to
            ``2 ** (exp_width - 1) - 1``.
        inplace: Forwarded to ``f.TruncateFP``. Rejected for half inputs,
            because the op would run on a temporary float copy.
        roundingMode: Forwarded unchanged.
        min_noise: Forwarded unchanged; defaults to
            ``DEFAULT_STOCH_ROUNDING_MIN``.
        max_noise: Forwarded unchanged; defaults to
            ``DEFAULT_STOCH_ROUNDING_MAX``.

    Returns:
        The result of ``f.TruncateFP().apply``, converted to half when the
        input was half.

    Raises:
        AssertionError: If ``inplace`` is truthy and ``input`` is
            ``torch.half``.

    NOTE(review): ``.contiguous()`` returns a copy for non-contiguous tensors,
    so with ``inplace=True`` the op would presumably modify that copy rather
    than the caller's tensor. Confirm this is intended.
    """
    if exp_bias is None:
        exp_bias = 2 ** (exp_width - 1) - 1
    is_half = input.dtype == torch.half
    if is_half:
        # Raised explicitly instead of via ``assert`` so the guard survives
        # ``python -O``; AssertionError is kept so existing callers that
        # catch it keep working.
        if inplace:
            raise AssertionError(
                "truncate_fp: inplace=True is not supported for torch.half input"
            )
        input = input.float()
    out = f.TruncateFP().apply(input.contiguous(), inplace, exp_width,
                               man_width, exp_bias, roundingMode, min_noise, max_noise)
    return out.half() if is_half else out


def truncate_fp8(input, inplace=False, exp_width=5, man_width=None, exp_bias=None,
                 roundingMode=0, min_noise=DEFAULT_STOCH_ROUNDING_MIN,
                 max_noise=DEFAULT_STOCH_ROUNDING_MAX):
    """Call ``truncate_fp`` with FP8-style default widths.

    Args:
        input: Tensor forwarded to ``truncate_fp``.
        inplace: Forwarded unchanged.
        exp_width: Exponent width; defaults to 5.
        man_width: Mantissa width. When ``None`` it is ``7 - exp_width``, so
            the two widths sum to 7 (presumably leaving one sign bit of the
            eight; confirm against lowp).
        exp_bias: Forwarded unchanged; ``truncate_fp`` fills in a default
            when it is ``None``.
        roundingMode: Forwarded unchanged.
        min_noise: Forwarded unchanged.
        max_noise: Forwarded unchanged.

    Returns:
        The result of ``truncate_fp``.
    """
    mantissa_bits = 7 - exp_width if man_width is None else man_width
    return truncate_fp(
        input, exp_width, mantissa_bits, exp_bias,
        inplace, roundingMode, min_noise, max_noise,
    )


def truncate_grad_fp8(input, exp_width=5, man_width=None, exp_bias=None,
                      roundingMode=0, min_noise=DEFAULT_STOCH_ROUNDING_MIN,
                      max_noise=DEFAULT_STOCH_ROUNDING_MAX):
    """Run ``f.TruncateGradFP`` on ``input`` with FP8-style default widths.

    Judging by the name, this presumably applies the truncation to the
    gradient rather than to the forward value; confirm against lowp.

    Args:
        input: Tensor forwarded to ``f.TruncateGradFP`` as it is.
        exp_width: Exponent width; defaults to 5.
        man_width: Mantissa width. When ``None`` it is ``7 - exp_width``.
        exp_bias: Exponent bias. When ``None`` it is
            ``2 ** (exp_width - 1) - 1``.
        roundingMode: Forwarded unchanged.
        min_noise: Forwarded unchanged.
        max_noise: Forwarded unchanged.

    Returns:
        The result of ``f.TruncateGradFP().apply``.

    NOTE(review): unlike ``truncate_grad_bf16``, ``input`` is passed without
    ``.contiguous()``, and there is no ``torch.half`` handling as in
    ``truncate_fp``. Confirm that ``f.TruncateGradFP`` accepts such inputs.
    """
    bias = 2 ** (exp_width - 1) - 1 if exp_bias is None else exp_bias
    mantissa_bits = 7 - exp_width if man_width is None else man_width
    truncate_grad = f.TruncateGradFP().apply
    return truncate_grad(
        input, exp_width, mantissa_bits, bias, roundingMode, min_noise, max_noise
    )


def bmm_bf16(x, y):
    """Compute ``torch.bmm`` with BF16-truncated operands.

    Both operands go through ``truncate_bf16`` before the product, and the
    product goes through ``truncate_grad_bf16`` before being returned.
    """
    lhs = truncate_bf16(x)
    rhs = truncate_bf16(y)
    product = torch.bmm(lhs, rhs)
    return truncate_grad_bf16(product)


def matmul_bf16(x, y):
    """Compute ``torch.matmul`` with BF16-truncated operands.

    Both operands go through ``truncate_bf16`` before the product, and the
    product goes through ``truncate_grad_bf16`` before being returned.
    """
    lhs = truncate_bf16(x)
    rhs = truncate_bf16(y)
    product = torch.matmul(lhs, rhs)
    return truncate_grad_bf16(product)


def add_bf16(*kargs):
    """Add any number of operands after passing each through ``truncate_bf16``.

    The truncated operands are accumulated with the built-in ``sum`` (which
    starts from integer 0), and the total goes through ``truncate_grad_bf16``.
    Intermediate partial sums are not truncated.
    """
    # Truncate every operand up front, then accumulate.
    truncated = [truncate_bf16(operand) for operand in kargs]
    total = sum(truncated)
    return truncate_grad_bf16(total)


def mul_bf16(*kargs):
    """Multiply any number of operands after passing each through ``truncate_bf16``.

    The running product starts as the first truncated operand, and each
    further operand is truncated and multiplied in from left to right. The
    final product goes through ``truncate_grad_bf16``. Intermediate products
    are not truncated. With no operands, the integer 1 is handed to
    ``truncate_grad_bf16``.
    """
    product = 1
    if kargs:
        first, *remaining = kargs
        product = truncate_bf16(first)
        for factor in remaining:
            product = product * truncate_bf16(factor)
    return truncate_grad_bf16(product)


def sigmoid_bf16(x):
    """Compute ``torch.sigmoid`` of the BF16-truncated ``x``.

    ``x`` goes through ``truncate_bf16`` first, and the activation goes
    through ``truncate_grad_bf16`` before being returned.
    """
    truncated = truncate_bf16(x)
    activated = torch.sigmoid(truncated)
    return truncate_grad_bf16(activated)


def tanh_bf16(x):
    """Compute ``torch.tanh`` of the BF16-truncated ``x``.

    ``x`` goes through ``truncate_bf16`` first, and the activation goes
    through ``truncate_grad_bf16`` before being returned.
    """
    truncated = truncate_bf16(x)
    activated = torch.tanh(truncated)
    return truncate_grad_bf16(activated)


def convert_bf16(input):
    """Pass ``input`` through ``truncate_bf16`` and then ``truncate_grad_bf16``.

    Both helpers are called with their default arguments.
    """
    return truncate_grad_bf16(truncate_bf16(input))
