import torch
import torch.nn.functional as F
import math
from prealign_utils import *

# input: FP32  (#batch, #tokens, in_features)
# weight: binary values in FP32 format (out_features, in_features)
def fp_linear_with_pre_alignment_correct(input, bweight, num_systolic_row):
    """Emulate ``input @ bweight.T`` with per-chunk significand pre-alignment.

    The reduction dimension is cut into ``ceil(in_features / num_systolic_row)``
    chunks. Inside each chunk:

    - the input significands are aligned to a shared exponent
      (``align_exponent`` / ``align_mantissa_correct``);
    - they are multiplied by the binary weights and summed as integers;
    - the sum is normalized and rounded back to floating-point fields by
      ``normalize_decomposed_fp_add``.

    The per-chunk results are then accumulated with ordinary tensor addition.

    Args:
        input: FP32 activations of shape ``(batch, in_features)`` or
            ``(batch, tokens, in_features)``.
        bweight: binary weights stored as FP32, shape
            ``(out_features, in_features)``; converted with ``.int()``.
        num_systolic_row: maximum number of input features reduced together
            in one aligned chunk; must be positive.

    Returns:
        The accumulated output, presumably shaped ``(..., out_features)``
        (its exact type is whatever ``reconstruct_fp`` builds).

    Raises:
        ValueError: if ``input`` is not 2-D or 3-D, if ``num_systolic_row`` is
            not positive, or if ``bweight``'s last dimension differs from
            ``in_features``.
    """
    # Validate up front, before any decomposition work is done.
    if input.dim() not in (2, 3):
        raise ValueError(
            f"input dim of fp linear should be 2 or 3 but we got {input.dim()}"
        )
    if num_systolic_row <= 0:
        raise ValueError(
            f"num_systolic_row should be positive but we got {num_systolic_row}"
        )

    in_features = input.size(-1)
    if bweight.size(-1) != in_features:
        raise ValueError(
            f"bweight in_features ({bweight.size(-1)}) does not match "
            f"input in_features ({in_features})"
        )

    # get hardware spec
    # NOTE(review): torch.chunk balances chunk sizes (e.g. 9 features with 4
    # rows -> 3+3+3 rather than 4+4+1). If the simulation should mirror the
    # physical row groups, `split(num_systolic_row, dim=-1)` is the literal
    # mapping -- TODO confirm which grouping is intended.
    num_split = math.ceil(in_features / num_systolic_row)
    input_chunks = input.chunk(num_split, dim=-1)
    bweight_chunks = bweight.chunk(num_split, dim=-1)

    output = None
    for input_part, bweight_part in zip(input_chunks, bweight_chunks):
        # decompose input into sign / exponent / fraction fields
        s_in, e_in, f_in = decompose_fp(input_part)

        # align significands to the chunk's shared exponent
        e_aligned, shift_amount = align_exponent(e_in)  # in debugging
        m_aligned = align_mantissa_correct(shift_amount, e_in, f_in)
        m_aligned = ((-1) ** s_in) * m_aligned  # re-attach the sign

        # (..., 1, chunk) * (out_features, chunk) -> (..., out_features, chunk);
        # broadcasting replaces the explicit per-rank expand.
        products = m_aligned.unsqueeze(-2) * bweight_part.int()
        # plain integer sum: no accumulator-width limit is modelled here
        m_add = products.sum(-1)

        # normalize and round the chunk result
        out_tmp = normalize_decomposed_fp_add(m_add, e_aligned)

        # accumulate chunk results
        if output is None:
            output = out_tmp
        else:
            output += out_tmp

    return output


# normalize fp output
def normalize_decomposed_fp_add(m_add, e_add):  # debugged
    """Normalize and round an integer significand sum back to FP fields.

    Args:
        m_add: signed integer sum of pre-aligned significands (as produced in
            ``fp_linear_with_pre_alignment_correct``).
        e_add: the shared aligned exponent of the summed operands; must
            broadcast against ``m_add``. ``e_add == 0`` is treated as the
            denormal-input case.

    Returns:
        ``reconstruct_fp(s_out, e_out, f_out)`` built from the sign, biased
        exponent and fraction fields.

    Rounding is round-to-nearest-even using guard/round/sticky bits taken
    from the bits shifted out of ``m_add``'s magnitude.
    """
    # classify the input exponent
    in_norm = e_add.gt(0)
    in_denorm = e_add.eq(0)

    # sign / magnitude split
    s_out = m_add.lt(0).int()
    m_out = m_add.abs()

    # normalize exponent: move the leading one down to bit `num_mbits`
    leading_one_bp = get_leading_one_bp(m_out)
    shift_amount = leading_one_bp - num_mbits
    e_out = e_add + shift_amount - num_mbits

    # classify the output
    out_zero = e_out.eq(0)
    out_norm = e_out.gt(0) | (out_zero & leading_one_bp.ge(2 * num_mbits))
    out_denorm = e_out.lt(0) | (out_zero & leading_one_bp.lt(2 * num_mbits))

    # update exponent for
    # denorm-to-norm: denormal operands carry an effective exponent of 1
    e_out[in_denorm & out_norm] += 1
    # any-to-denorm: pin the exponent at 0 and shift further right instead
    shift_amount[out_denorm] = shift_amount[out_denorm] - e_out[out_denorm]
    e_out[out_denorm] = 0
    # norm-to-denorm
    shift_amount[in_norm & out_denorm] += 1

    # extract the (unrounded) fraction
    f_out = m_out.clone()
    right_loc = shift_amount.gt(0)
    left_loc = shift_amount.lt(0)
    f_out[right_loc] = m_out[right_loc] >> shift_amount[right_loc]
    f_out[left_loc] = m_out[left_loc] << shift_amount[left_loc].abs()

    # drop the implicit leading one for normal outputs
    f_out[out_norm] = f_out[out_norm] - mantissa_upbound
    f_out = f_out.int()

    # guard / round / sticky bits. Masks are built from an int64 shift count
    # so bit positions >= 31 do not overflow (the previous `.int()` truncated
    # 2**31 and above). A non-positive exponent yields a 0 mask (torch integer
    # pow returns 0 for negative exponents), i.e. no rounding bits.
    shift64 = shift_amount.long()
    g = m_out.bitwise_and(2 ** (shift64 - 1)).gt(0)
    r = m_out.bitwise_and(2 ** (shift64 - 2)).gt(0)
    sticky_loc = shift64.ge(3)
    s = torch.zeros_like(g)
    s[sticky_loc] = (m_out[sticky_loc] % (2 ** (shift64[sticky_loc] - 2))).gt(0)

    # round to nearest, ties to even:
    # g & (r | s)  -> above halfway;  g & ~r & ~s & f1 -> tie with odd LSB.
    f1 = f_out.bitwise_and(1).gt(0)
    round_up_loc = g & (r | s | f1)
    f_out[round_up_loc] += 1

    # re-norm: rounding adds at most 1 to a fraction < mantissa_upbound, so a
    # carry-out yields exactly mantissa_upbound. The significand is then an
    # exact power of two (2.0 for normals, the smallest normal for denormals),
    # so bump the exponent and clear the fraction.
    renorm_loc = f_out.ge(mantissa_upbound)
    f_out[renorm_loc] = 0
    e_out[renorm_loc] += 1

    # an exact-zero sum (e.g. x*(+1) + x*(-1)) must come out as +0 regardless
    # of what get_leading_one_bp reports for a zero magnitude
    zero_loc = m_out.eq(0)
    e_out[zero_loc] = 0
    f_out[zero_loc] = 0

    # NOTE(review): exponent overflow (e_out above the format's max) is not
    # saturated to infinity here -- confirm whether reconstruct_fp handles it.
    return reconstruct_fp(s_out, e_out, f_out)

# align mantissa
def align_mantissa_correct(shift_amount, e_in, f_in):  # NOTE: debugged
    """Build each operand's fixed-point significand and right-align it.

    - The fraction field is widened to int64 and moved up by ``num_mbits`` bits.
      This leaves that many extra low-order bits, so the later right shift
      keeps bits that would otherwise be lost.
    - Operands with a non-zero exponent also receive the implicit leading one
      (``mantissa_upbound``) at the same scale; zero-exponent operands do not.
    - The result is shifted right by ``shift_amount``, with the shift count
      capped at 48.

    Args:
        shift_amount: per-operand right-shift count (from ``align_exponent``).
        e_in: per-operand exponent field; ``e_in > 0`` marks normal values.
        f_in: per-operand fraction field.

    Returns:
        int64 tensor of aligned, unsigned significands.
    """
    capped_shift = shift_amount.clamp(max=48)

    fraction = f_in.long() << num_mbits
    hidden_one = mantissa_upbound << num_mbits
    significand = torch.where(e_in.gt(0), fraction + hidden_one, fraction)

    return significand >> capped_shift


# old version
# # input: FP32  (#batch, #tokens, in_features)
# # weight: binary values in FP32 format (out_features, in_features)
# def fp_linear_with_pre_alignment_correct(input, bweight, num_systolic_row):
    
#     # get features size
#     in_features = input.size(-1)
#     out_features = bweight.size(0)

#     # get hardware spec
#     num_split = math.ceil(in_features / num_systolic_row)

#     # split input and bweight
#     input_chunk = input.chunk(num_split, dim=-1)
#     bweight_chunk = bweight.chunk(num_split, dim=-1)

#     # do compuation for each chunk
#     output = None
#     for i in range(0, num_split):
#         # decompose input
#         s_in, e_in, f_in = decompose_fp(input_chunk[i])

#         # align significand
#         e_aligned, shift_amount = align_exponent(e_in)
#         m_aligned = align_mantissa_correct(shift_amount, e_in, f_in)
#         m_aligned = ((-1) ** s_in) * m_aligned # attach extra bits to m_aligned

#         # reshape m_aglined
#         if (m_aligned.dim() == 2): # input 2dim
#            m_aligned = m_aligned.unsqueeze(1).expand(-1, out_features, -1) 
#         elif (m_aligned.dim() == 3): # input 3dim
#            m_aligned = m_aligned.unsqueeze(2).expand(-1, -1, out_features, -1) 
#         else:
#             assert False, "input dim of fp linear should be 2 or 3 but we got {input.dim()}"

#         # do computation
#         tmp = m_aligned * bweight_chunk[i].int()
#         m_add = tmp.sum(-1) # as we used F.linear for accum --> no limit on accumulator size

#         # normalize output
#         out_tmp = normalize_decomposed_fp_add(m_add, e_aligned)

#         # update final output
#         if ( i == 0 ):
#             output = out_tmp
#         else:
#             output += out_tmp

#     return output


# # normalize fp output
# def normalize_decomposed_fp_add(m_add, e_add):

#     # get sing
#     s_out = m_add.lt(0).int()
#     # get mantissa
#     m_out = m_add.abs()

#     # normalize exponent
#     leading_one_bp =  get_leading_one_bp(m_out)
#     shift_amount = leading_one_bp - num_mbits
#     e_out = e_add + shift_amount - num_mbits
#     denorm_loc = e_out.le(0)
#     norm_loc = e_out.gt(0)

#     # update exponent for denormals
#     shift_amount[denorm_loc] = shift_amount[denorm_loc] - e_out[denorm_loc] + 1
#     e_out[denorm_loc] = 0

#     # update fraction
#     f_out = m_out.clone()
#     f_out[shift_amount.gt(0)] = m_out[shift_amount.gt(0)] >> shift_amount[shift_amount.gt(0)]
#     f_out[shift_amount.lt(0)] = m_out[shift_amount.lt(0)] << shift_amount[shift_amount.lt(0)].abs()
#     # subtract mantissa_upbound for normals
#     f_out[norm_loc] = f_out[norm_loc] - mantissa_upbound
#     f_out = f_out.int()

#     # get grs
#     #round_loc = shift_amount.lt(0)
#     #shift_amount = shift_amount.abs()
#     g = m_out.bitwise_and(2 ** (shift_amount - 1)).gt(0)
#     r = m_out.bitwise_and(2 ** (shift_amount - 2)).gt(0)
#     # get s
#     tmp_loc = shift_amount.ge(3)
#     s = g.clone().fill_(False)
#     s[tmp_loc] = (m_out[tmp_loc] % (2 ** (shift_amount[tmp_loc] - 2))).gt(0)

#     # rounding
#     f1 = f_out.bitwise_and(1).gt(0)
#     roundUp_loc = ( g & (r | s) ) | ( f1 & g & ((~r) & (~s)) )  #g=1 and r or s=1 / g=1 and rs"00" and f1=1
#     #roundUp_loc = round_loc & roundUp_loc
#     f_out[roundUp_loc] += 1

#     # re-norm
#     renorm_loc = f_out.ge(mantissa_upbound)
#     f_out[renorm_loc] = (f_out[renorm_loc] << 1) + g[renorm_loc].int()
#     e_out[renorm_loc] += 1
    
#     return reconstruct_fp(s_out, e_out, f_out)

# # align mantissa
# def align_mantissa_correct(shift_amount, e_in, f_in):
    
#     # get location of denorm & norm
#     norm_loc = e_in.gt(0)
#     denorm_loc = e_in.eq(0)
    
#     # align mentissa
#     f_in = (f_in.long()) << num_mbits
#     m_align = f_in.clone() 
#     m_align[norm_loc] = ((mantissa_upbound << num_mbits) + f_in[norm_loc]) >> shift_amount[norm_loc]
#     m_align[denorm_loc] = f_in[denorm_loc] >> shift_amount[denorm_loc]
    
#     return m_align

