import torch
import torch.nn.functional as F
from prealign_utils import *
import math

# conv -> unfold + MM + fold
# input: FP32  (N, Cin, H, W)
# weight: binary values in FP32 format (Cout, Cin, kH, kW)
def fp_conv2d_with_pre_alignment(input, bweight, num_systolic_row,
                                    dilation=1, padding=0, stride=1, # information required for folding
                                    num_extra_bits=3):
    """2-D convolution expressed as unfold -> pre-aligned linear -> fold.

    Args:
        input: FP32 tensor laid out as (N, Cin, H, W).
        bweight: binary-valued weights stored as FP32, laid out as
            (Cout, Cin, kH, kW).
        num_systolic_row: forwarded to fp_linear_with_pre_alignment, where it
            bounds how many input features are processed per chunk.
        dilation, padding, stride: convolution geometry, forwarded to
            F.unfold and to get_conv2d_output_size.
        num_extra_bits: forwarded to fp_linear_with_pre_alignment.

    Returns:
        Tensor of shape (N, Cout, h_out, w_out).
    """
    # get conv info
    kernel_size = (bweight.size(2), bweight.size(3))
    h_out, w_out = get_conv2d_output_size(input.size(2), input.size(3), kernel_size, dilation, padding, stride)

    # unfold input into sliding-window columns: (N, Cin * kH * kW, #out_spatial)
    patches = F.unfold(input, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
    patches = patches.transpose(1, 2) # (N, #out_spatial, Cin * kH * kW)

    # Flatten the kernel to (Cout, Cin * kH * kW) so its last dimension matches
    # the unfolded feature dimension. The linear routine chunks both operands
    # along their last dim; passing the 4-D weight would chunk along kW only.
    flat_weight = bweight.reshape(bweight.size(0), -1)

    # linear with pre-alignment: (N, #out_spatial, Cout)
    output = fp_linear_with_pre_alignment(patches, flat_weight, num_systolic_row, num_extra_bits)

    # fold output back to (N, Cout, h_out, w_out); a 1x1 kernel makes fold a pure reshape
    output = output.transpose(1, 2)
    return F.fold(output, (h_out, w_out), (1, 1))


# input: FP32  (#batch, #tokens, in_features)
# weight: binary values in FP32 format (out_features, in_features)
def fp_linear_with_pre_alignment(input, bweight, num_systolic_row, num_extra_bits=3):
    """Bias-free linear layer computed chunk-wise on pre-aligned mantissas.

    The feature dimension is split into ceil(in_features / num_systolic_row)
    chunks. For each chunk the input is decomposed into sign / exponent /
    mantissa, the mantissas are aligned, multiplied by the integer weights and
    summed, and the per-chunk sum is normalized back by
    normalize_decomposed_fp_add. The per-chunk results are then added together.

    Args:
        input: FP32 tensor, (#batch, in_features) or
            (#batch, #tokens, in_features).
        bweight: binary-valued weights stored as FP32,
            (out_features, in_features). Converted with ``.int()`` before use.
        num_systolic_row: divisor used to decide how many chunks the feature
            dimension is split into.
        num_extra_bits: number of extra low-order bits appended to each
            aligned mantissa before accumulation.

    Returns:
        The accumulated output, shape (..., out_features).

    Raises:
        AssertionError: if ``input`` is not 2- or 3-dimensional.
    """
    # Reject unsupported ranks before doing any work. Raised explicitly (not
    # via `assert`) so the check survives `python -O`, and built as a real
    # f-string so the offending rank is actually shown.
    if input.dim() not in (2, 3):
        raise AssertionError(
            f"input dim of fp linear should be 2 or 3 but we got {input.dim()}"
        )

    # get features size
    in_features = input.size(-1)

    # get hardware spec: number of chunks along the feature dimension
    num_split = math.ceil(in_features / num_systolic_row)

    # split input and bweight along the feature dimension
    input_chunk = input.chunk(num_split, dim=-1)
    bweight = bweight.int()
    bweight_chunk = bweight.chunk(num_split, dim=-1)

    # free unused memory
    torch.cuda.empty_cache()

    # do computation for each chunk
    output = None
    for i in range(num_split):
        # decompose input into sign, exponent and mantissa
        s, e, m = decompose_fp(input_chunk[i])

        # align significand
        e_aligned, shift_amount = align_exponent(e)
        m_aligned = align_mantissa(shift_amount, e, m)
        extra_bits = get_extra_bits(shift_amount, m, num_extra_bits)
        # attach extra bits below m_aligned and apply the sign
        m = ((-1) ** s) * ((m_aligned << num_extra_bits) + extra_bits)

        # free unused memory
        del s
        del e
        del extra_bits
        torch.cuda.empty_cache()

        # Insert an out_features axis before the feature axis so that
        # broadcasting against the (out_features, chunk) weights yields
        # (..., out_features, chunk) for both 2-D and 3-D inputs.
        # assumes decompose_fp keeps the rank of its input -- TODO confirm
        m_add = m.unsqueeze(-2) * bweight_chunk[i]
        # plain tensor sum over the chunk: no accumulator-width limit is modelled here
        m_add = m_add.sum(-1)

        # free unused memory
        del m
        torch.cuda.empty_cache()

        # normalize output
        m_add = normalize_decomposed_fp_add(m_add, e_aligned, num_extra_bits)

        # update final output
        if output is None:
            output = m_add
        else:
            output += m_add

        # free unused memory
        del m_add
        del e_aligned
        torch.cuda.empty_cache()

    return output

# normalize fp output
def normalize_decomposed_fp_add(m_add, e_add, num_extra_bits=3):
    """Normalize an accumulated signed mantissa sum and rebuild a float tensor.

    Args:
        m_add: signed accumulated mantissa sum; its sign (m_add < 0) becomes
            the output sign and its absolute value the working mantissa.
            # assumes an integer tensor -- TODO confirm against caller
        e_add: shared (aligned) exponent for the sum. Entries > 0 are treated
            as normal inputs, entries == 0 as denormal inputs.
        num_extra_bits: number of extra low-order bits carried in m_add; it is
            subtracted from the output exponent.

    Returns:
        Whatever reconstruct_fp(s_out, e_out, f_out) returns for the computed
        sign, exponent and rounded fraction.
    """

    # classify the incoming exponent: normal (> 0) vs. denormal (== 0)
    in_norm = e_add.gt(0)
    in_denorm = e_add.eq(0)
    
    # get sign (1 where the sum is negative)
    s_out = m_add.lt(0).int()
    # get mantissa magnitude
    m_out = m_add.abs()

    # normalize exponent: shift_amount is the leading-one position relative to
    # num_mbits; the exponent is adjusted by it and by the extra bits
    leading_one_bp =  get_leading_one_bp(m_out)
    shift_amount = leading_one_bp - num_mbits
    e_out = e_add + shift_amount - num_extra_bits

    # classify the output: normal if e_out > 0, denormal if e_out < 0; when
    # e_out == 0 the leading-one position decides between the two
    out_zero = e_out.eq(0)
    out_norm = e_out.gt(0) | (out_zero & leading_one_bp.ge(num_mbits + num_extra_bits))
    out_denorm = e_out.lt(0) | (out_zero & leading_one_bp.lt(num_mbits + num_extra_bits))
    
    norm_loc = out_norm

    # update exponent / shift for each transition (order matters: the
    # any-to-denorm step reads e_out before it is zeroed)
    # denorm-to-norm
    e_out[in_denorm & out_norm]+= 1
    # any-to-denorm: fold the (non-positive) exponent into the shift, then clear it
    shift_amount[out_denorm] = shift_amount[out_denorm] - e_out[out_denorm]
    e_out[out_denorm] = 0
    # norm-to-denorm
    shift_amount[in_norm & out_denorm] += 1

    # update fraction: positive shift_amount shifts right, negative shifts
    # left by its magnitude, zero leaves the value unchanged
    f_out = m_out.clone()
    f_out[shift_amount.gt(0)] = m_out[shift_amount.gt(0)] >> shift_amount[shift_amount.gt(0)]
    f_out[shift_amount.lt(0)] = m_out[shift_amount.lt(0)] << shift_amount[shift_amount.lt(0)].abs()
    # subtract mantissa_upbound for normals
    f_out[norm_loc] = f_out[norm_loc] - mantissa_upbound
    f_out = f_out.int()
    
    # get guard (g) and round (r) bits: the two highest bits shifted out of m_out
    # NOTE(review): where shift_amount - 1 (or - 2) is negative this relies on
    # 2 ** negative becoming 0 after .int(), so the bit reads as False -- confirm
    g = m_out.bitwise_and((2 ** (shift_amount - 1)).int()).gt(0)
    r = m_out.bitwise_and((2 ** (shift_amount - 2)).int()).gt(0)
    # get sticky (s): any non-zero bit below the round bit; only computed
    # where at least three bits were shifted out, False elsewhere
    tmp_loc = shift_amount.ge(3)
    s = g.clone().fill_(False)
    s[tmp_loc] = (m_out[tmp_loc] % (2 ** (shift_amount[tmp_loc] - 2))).gt(0)

    # rounding: round up when g=1 and (r or s)=1, or on a tie (g=1, r=s=0)
    # when the fraction LSB (f1) is 1
    f1 = f_out.bitwise_and(1).gt(0)
    roundUp_loc = ( g & (r | s) ) | ( f1 & g & ((~r) & (~s)) )  #g=1 and r or s=1 / g=1 and rs"00" and f1=1
    #roundUp_loc = round_loc & roundUp_loc
    f_out[roundUp_loc] += 1

    # re-norm: where the fraction reached mantissa_upbound, bump the exponent
    # NOTE(review): the fraction is shifted LEFT and g is added here; confirm
    # this is the intended handling together with reconstruct_fp
    renorm_loc = f_out.ge(mantissa_upbound)
    f_out[renorm_loc] = (f_out[renorm_loc] << 1) + g[renorm_loc].int()
    e_out[renorm_loc] += 1
    
    return reconstruct_fp(s_out, e_out, f_out)

# align mantissa
def align_mantissa(shift_amount, e_in, f_in):
    """Right-shift each mantissa by its alignment shift.

    Args:
        shift_amount: per-element right-shift amount; values above 31 are
            capped at 31.
        e_in: per-element exponent; entries > 0 are treated as normal numbers.
        f_in: per-element fraction bits. Not modified.

    Returns:
        A new tensor holding the aligned mantissas: ``f_in`` (plus
        ``mantissa_upbound`` where ``e_in > 0``) shifted right by the capped
        shift amount.
    """
    # Cap the shift at 31 without touching the caller's tensor.
    # presumably keeps the shift below the 32-bit word width -- TODO confirm
    shift = shift_amount.clamp(max=31)

    # Normal numbers get mantissa_upbound added to the fraction; all other
    # entries keep the bare fraction.
    m_align = f_in.clone()
    m_align[e_in.gt(0)] += mantissa_upbound

    return m_align >> shift




# old version (superseded implementations, kept commented out below for reference)
# # conv -> unfold + MM + fold
# # input: FP32  (N, Cin, H, W)
# # weight: binary values in FP32 format (Cout, Cin, kH, kW)
# def fp_conv2d_with_pre_alignment(input, bweight, num_systolic_row, 
#                                     dilation=1, padding=0, stride=1, # information required for folding 
#                                     num_extra_bits=3):
#     # get conv info
#     kernel_size = (bweight.size(2), bweight.size(3))
#     h_out, w_out = get_conv2d_output_size(input.size(2), input.size(3), kernel_size, dilation, padding, stride)
#     # unfold input
#     input = F.unfold(input, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
#     input = input.transpose(1,2) # (N, #out_spatial, Cin * H * W)
#     # unfold weight
#     bweight = bweight.int()
#     bweight = bweight.view(bweight.size(0), -1) # (Cout, Cin * H * W)
    
#     # get features size
#     in_features = input.size(-1)
#     c_out = bweight.size(0)

#     # get hardware spec
#     num_split = math.ceil(in_features / num_systolic_row)

#     # split input and bweight
#     input_chunk = input.chunk(num_split, dim=-1)
#     bweight_chunk = bweight.chunk(num_split, dim=-1)

#     # free unsued memory
#     torch.cuda.empty_cache()

#     # do compuation for each chunk
#     output = None
#     for i in range(0, num_split):
#         # decompose input
#         s, e, m = decompose_fp(input_chunk[i])

#         # align significand
#         e, m, extra_bits = align_significand(e, m)
#         m = ((-1) ** s) * ((m << num_extra_bits) + extra_bits) # attach extra bits to m_aligned

#         # free unused memory
#         del s
#         del extra_bits
#         torch.cuda.empty_cache()

#         # get sum of m
#         m_add = []
#         for c in range(0, c_out):
#             m_add.append((m * bweight_chunk[i][c,:]).sum(-1))
#         # free unused memory
#         del m
#         torch.cuda.empty_cache()
#         # stack sum
#         m_add = torch.stack(m_add, dim=2)

#         # normalize output (m - fp from here)
#         m_add = normalize_decomposed_fp_add(m_add, e, num_extra_bits)

#         # update final output
#         if ( i == 0 ):
#             output = m_add
#         else:
#             output += m_add

#         # free unused memory
#         del m_add
#         del e
#         torch.cuda.empty_cache()


#     # fold output
#     output = output.transpose(1, 2)
#     output = F.fold(output, (h_out, w_out), (1, 1))

#     return output

# # input: FP32  (#batch, #tokens, in_features)
# # weight: binary values in FP32 format (out_features, in_features)
# def fp_linear_with_pre_alignment(input, bweight, num_systolic_row, num_extra_bits=3):
    
#     # get features size
#     in_features = input.size(-1)
#     out_features = bweight.size(0)

#     # get hardware spec
#     num_split = math.ceil(in_features / num_systolic_row)

#     # split input and bweight
#     input_chunk = input.chunk(num_split, dim=-1)
#     bweight = bweight.int()
#     bweight_chunk = bweight.chunk(num_split, dim=-1)

#     # free unsued memory
#     torch.cuda.empty_cache()

#     # do compuation for each chunk
#     output = None
#     for i in range(0, num_split):
#         # decompose input
#         s, e, m = decompose_fp(input_chunk[i])

#         # align significand
#         e, m, extra_bits = align_significand(e, m)
#         m = ((-1) ** s) * ((m << num_extra_bits) + extra_bits) # attach extra bits to m_aligned

#         # free unused memory
#         del s
#         del extra_bits
#         torch.cuda.empty_cache()

#         # reshape m_aglined
#         if (m.dim() == 2): # input 2dim
#            m = m.unsqueeze(1).expand(-1, out_features, -1) 
#         elif (m.dim() == 3): # input 3dim
#            m = m.unsqueeze(2).expand(-1, -1, out_features, -1) 
#         else:
#             assert False, "input dim of fp linear should be 2 or 3 but we got {input.dim()}"

#         # do computation
#         m_add = m * bweight_chunk[i]
#         m_add = m_add.sum(-1) # as we used F.linear for accum --> no limit on accumulator size

#         # free unused memory
#         del m
#         torch.cuda.empty_cache()

#         # normalize output
#         m_add = normalize_decomposed_fp_add(m_add, e, num_extra_bits)

#         # update final output
#         if ( i == 0 ):
#             output = m_add
#         else:
#             output += m_add

#         # free unused memory
#         del m_add
#         del e
#         torch.cuda.empty_cache()

#     return output

# # normalize fp output
# def normalize_decomposed_fp_add(m_add, e_add, num_extra_bits=3):

#     # get sing
#     s_out = m_add.lt(0).int()
#     # get mantissa
#     m_out = m_add.abs() >> num_extra_bits # remove extra 3 bits after the computation

#     # normalize exponent
#     leading_one_bp =  get_leading_one_bp(m_out)
#     shift_amount = leading_one_bp - num_mbits
#     e_out = e_add + shift_amount
#     denorm_loc = e_out.le(0)
#     norm_loc = e_out.gt(0)

#     # update exponent for denormals
#     shift_amount[denorm_loc] = shift_amount[denorm_loc] - e_out[denorm_loc] + 1
#     e_out[denorm_loc] = 0

#     # update fraction
#     f_out = m_out.clone()
#     f_out[shift_amount.gt(0)] = m_out[shift_amount.gt(0)] >> shift_amount[shift_amount.gt(0)]
#     f_out[shift_amount.lt(0)] = m_out[shift_amount.lt(0)] << shift_amount[shift_amount.lt(0)].abs()
#     # subtract mantissa_upbound for normals
#     f_out[norm_loc] = f_out[norm_loc] - mantissa_upbound
    
#     return reconstruct_fp(s_out, e_out, f_out)

# # align mantissa
# def align_mantissa(shift_amount, e_in, f_in):
    
#     # get location of denorm & norm
#     norm_loc = e_in.gt(0)
#     denorm_loc = e_in.eq(0)
    
#     # align mentissa
#     m_align = f_in.clone()
#     m_align[norm_loc] = (mantissa_upbound + f_in[norm_loc]) >> shift_amount[norm_loc]
#     m_align[denorm_loc] = f_in[denorm_loc] >> shift_amount[denorm_loc]
    
#     return m_align
