import torch
import torch.nn.functional as F
import math

# IEEE-754 single-precision (FP32) field widths used throughout this module
num_fbits = 23                      # number of stored fraction bits
num_mbits = num_fbits + 1           # significand width including the hidden leading one
e_bias = 127                        # exponent bias
mantissa_upbound = 1 << num_fbits   # 0x800000, the weight of the hidden leading one

# decompose FP32 values and extract sign/exponent/mantissa
def decompose_fp(input):
    """Split floating-point values into sign, biased exponent and fraction fields.

    The bit widths (``num_fbits``, ``num_mbits``, ``e_bias``) are the FP32
    ones, so the fields correspond to the FP32 encoding of each value.
    There is no special handling for inf/NaN.

    Args:
        input: floating-point tensor to decompose.

    Returns:
        (s, e, f):
            s -- bool tensor, True where the value is negative
                 (``-0.0`` yields False).
            e -- int tensor with the biased exponent; 0 for zeros and
                 denormals.
            f -- int tensor with the stored fraction bits (the hidden
                 leading one of normal numbers is not included).
    """
    # exact zeros: frexp reports an exponent of 0 for them, which after
    # biasing would look like a normal number, so they are tracked separately
    zero_loc = input.eq(0)

    # torch.frexp gives val = m * 2**e with 0.5 <= |m| < 1, so e is one
    # larger than the unbiased IEEE exponent (val = 1.f * 2**(e - 1))
    m, e = torch.frexp(input)

    # sign kept as a boolean mask for low memory usage (True = negative)
    s = m.lt(0)

    # normal numbers: |m| * 2**num_mbits lies in [2**num_fbits, 2**num_mbits);
    # subtracting the hidden-one weight leaves just the fraction bits
    m = m.abs()
    f = (m * (2 ** num_mbits)).int() - mantissa_upbound

    # biased IEEE exponent; a biased exponent <= 0 means a denormal
    e = e + e_bias - 1
    denorm_loc = e.le(0) & ~zero_loc

    # denormals: f = m * 2**(num_fbits + e). For FP32 inputs the exponent
    # here stays small and positive, so the integer power is exact.
    f[denorm_loc] = (m[denorm_loc] * (2 ** (num_fbits + e[denorm_loc]))).int()
    # zeros are encoded as f = 0; setting it directly avoids evaluating an
    # int32 power like 2**149 that only "worked" through overflow wraparound
    f[zero_loc] = 0

    # denormals and zeros share biased exponent 0
    e[denorm_loc | zero_loc] = 0

    return s, e, f


# reconstruct FP values from sign / biased exponent / fraction fields (s/e/f)
def reconstruct_fp(s_in, e_in, f_in):
    """Rebuild floating-point values from sign / biased-exponent / fraction fields.

    Decoding rules:
      * e_in > 0  -> (-1)**s * (1 + f / 2**num_fbits) * 2**(e - e_bias)
      * e_in == 0 -> (-1)**s * (f / 2**num_fbits) * 2**(1 - e_bias)
    Entries with e_in == 0 and f_in == 0 are forced to 0.0. Entries with a
    negative exponent match neither rule and stay 0. There is no special
    handling for exponent 255 (inf/NaN).

    Args:
        s_in: sign tensor (True / 1 means negative).
        e_in: biased exponent tensor.
        f_in: fraction-bit tensor.

    Returns:
        Floating-point tensor shaped like ``s_in``.
    """
    frac_scale = 2 ** num_fbits
    is_norm = e_in > 0
    is_denorm = e_in == 0
    is_zero = is_denorm & (f_in == 0)

    # zero-filled float tensor with the same shape as the sign input
    result = s_in * 0.0

    # normal numbers carry an implicit leading one
    sign = (-1) ** s_in[is_norm]
    mantissa = 1 + f_in[is_norm] / frac_scale
    result[is_norm] = sign * mantissa * (2.0 ** (e_in[is_norm] - e_bias))

    # denormals: no implicit one, scaled by the minimum normal exponent
    sign = (-1) ** s_in[is_denorm]
    mantissa = f_in[is_denorm] / frac_scale
    result[is_denorm] = sign * mantissa * (2.0 ** (1 - e_bias))

    # exact zeros (also turns any -0.0 produced above into +0.0)
    result[is_zero] = 0.0
    return result

# align exponent of values in the last dim
def align_exponent(e_in):
    """Align every exponent along the last dimension to that dimension's maximum.

    Args:
        e_in: integer tensor of biased exponents.

    Returns:
        e_aligned: maximum exponent along the last dim, with a trailing
            size-1 dim so it broadcasts against ``e_in``.
        shift_amount: ``e_aligned - e_in``, except that elements with
            exponent 0 and a positive shift get one less. An exponent of 0
            (denormal) is scaled like an exponent of 1 in IEEE encoding.
    """
    e_aligned = torch.max(e_in, dim=-1).values.unsqueeze(-1)
    shift_amount = e_aligned - e_in

    # denormals (e == 0) already sit at the e == 1 scale, so they need one
    # shift less whenever they are being shifted at all
    needs_adjust = (e_in == 0) & (shift_amount > 0)
    shift_amount = torch.where(needs_adjust, shift_amount - 1, shift_amount)

    return e_aligned, shift_amount


# extract 3 extra bits after shifting
def get_extra_bits(shift_amount, f_in, num_extra_bits=3):  # default: 3 extra bits
    """Capture the top bits of ``f_in`` that a right shift by ``shift_amount`` drops.

    The bits at positions ``shift_amount - 1`` down to
    ``shift_amount - num_extra_bits`` are masked out of ``f_in``. They are then
    moved so they occupy the low ``num_extra_bits`` bits of the result, most
    significant first. Positions below bit 0 contribute nothing.
    Presumably used for rounding after exponent alignment; confirm with callers.

    NOTE: ``shift_amount`` must be an integer tensor.
    """
    # one mask bit per extra position; a negative exponent comes out as 0
    # from integer pow, so out-of-range positions add nothing to the mask
    if num_extra_bits > 0:
        mask = sum(2 ** (shift_amount - k - 1) for k in range(num_extra_bits)).int()
    else:
        mask = 0
    captured = torch.bitwise_and(f_in, mask)

    # boundary above the low bits: slide the captured bits down by the excess
    shift_down = shift_amount > num_extra_bits
    captured[shift_down] = captured[shift_down] >> (shift_amount[shift_down] - num_extra_bits)

    # boundary at or below: slide them up by the shortfall
    shift_up = shift_amount <= num_extra_bits
    captured[shift_up] = captured[shift_up] << (num_extra_bits - shift_amount[shift_up])

    return captured


# get the bit position of leading one
def get_leading_one_bp(_val):
    """Return the 1-based bit position of the leading one of each element.

    A value v >= 1 gets the position p with 2**(p-1) <= v < 2**p. Elements
    that are <= 0 get 0. The input is not modified. The result has the same
    dtype and shape as ``_val``.
    """
    work = _val.clone()
    result = torch.zeros_like(work)

    # number of halvings needed to bring the largest element below 1
    peak = work.max()
    n_bits = 0
    while peak >= 1:
        peak = peak / 2
        n_bits += 1

    # scan from the highest position down. An element above 2**(pos-1) - 1
    # has its leading one at `pos`. Zeroing it afterwards keeps lower
    # positions from overwriting the answer.
    for pos in range(n_bits, 0, -1):
        hit = work.gt((1 << (pos - 1)) - 1)
        result[hit] = pos
        work[hit] = 0

    return result


# calculate output size of conv2d
def get_conv2d_output_size(h_in, w_in, kernel_size, dilation=1, padding=0, stride=1):
    """Compute the spatial output size of a 2-D convolution.

    Uses out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
    independently for height and width.

    Args:
        h_in, w_in: input height and width.
        kernel_size, dilation, padding, stride: either a single int applied
            to both dimensions or a (height, width) pair (tuple or list).

    Returns:
        (h_out, w_out)
    """
    # normalize every hyper-parameter to a (height, width) pair; the old
    # stride check compared a type against the stride value, so tuple
    # strides were double-wrapped and crashed
    kernel_size = _to_pair(kernel_size)
    dilation = _to_pair(dilation)
    padding = _to_pair(padding)
    stride = _to_pair(stride)

    h_out = (h_in + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1)) - 1) // stride[0] + 1
    w_out = (w_in + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1)) - 1) // stride[1] + 1

    return h_out, w_out


def _to_pair(value):
    """Return ``value`` as a 2-tuple, repeating a scalar for both spatial dims."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)

