#################################################################################
#                                                                               # 
#                                WARNING !!!                                    # 
#                                                                               # 
#       Source code for the NeurIPS 2021 submission, to reproduce the results.  #
#       Once the paper is accepted, the code will be made publicly available.   #
#       Until then, we strongly request that you do not distribute the code.    #
#       Thank you!^^                                                            #
#                                                                               # 
#################################################################################



import torch

# "rexp" lookup tables for softmax_LUT_rexp.
# The function indexes the selected table with round(max(x) - x), clamped to
# the last entry, and divides the looked-up value by a power-of-two scale.
# The values appear to follow round(MAX_CODE * exp(-i)) for index i, with
# MAX_CODE = 32767 / 255 / 15 / 3 for the int16 / uint8 / uint4 / uint2
# variants respectively.
# Only rexp_LUT_1x13_int16 is referenced by active code in this file; the other
# variants are referenced only from commented-out lines in softmax_LUT_rexp.
rexp_LUT_1x13_int16 =[                      
32767,  12054,   4435,   1631,   600,   221,   81,
   30,     11,      4,      1,     1,     0
]

# 8-entry variant with 8-bit codes (paired with scale = 256 in the commented-out config).
rexp_LUT_1x8_uint8 =[                      
255,   94,   35,   13,   5,   2,   1,   0
]

# 5-entry variant with 4-bit codes (paired with scale = 16 in the commented-out config).
rexp_LUT_1x5_uint4 =[
15,  6,   2,   1,   0
]

# 3-entry variant with 2-bit codes (paired with scale = 4 in the commented-out config).
rexp_LUT_1x3_uint2 =[
3,   1,   0
]



# Reciprocal ("expln") lookup table with 512 int16 entries.
# softmax_LUT_rexp indexes such a table with round(sum(e_x)) - 1 and divides the
# looked-up value by scale_eln (32768 for the int16 tables), so entry k acts as
# a fixed-point approximation of 1 / (k + 1).
# The values appear to follow round(32767 / (k + 1)); the final entry is forced
# to 0 (see the trailing comment), so sums clamped to the last index yield 0.
# This table is referenced only from a commented-out line in softmax_LUT_rexp.
expln_LUT_1x512_int16 =[                      
32767,   16384,  10922,   8192,   6553,   5461,   4681,   4096,   3641,   3277,
 2979,    2731,   2521,   2341,   2184,   2048,   1927,   1820,   1725,   1638,
 1560,    1489,   1425,   1365,   1311,   1260,   1214,   1170,   1130,   1092,
 1057,    1024,    993,    964,    936,    910,    886,    862,    840,    819,
  799,     780,    762,    745,    728,    712,    697,    683,    669,    655,
  642,     630,    618,    607,    596,    585,    575,    565,    555,    546, 
  537,     529,    520,    512,    504,    496,    489,    482,    475,    468,
  462,     455,    449,    443,    437,    431,    426,    420,    415,    410, 
  405,     400,    395,    390,    385,    381,    377,    372,    368,    364,
  360,     356,    352,    349,    345,    341,    338,    334,    331,    328,
  324,     321,    318,    315,    312,    309,    306,    303,    301,    298,
  295,     293,    290,    287,    285,    282,    280,    278,    275,    273,
  271,     269,    266,    264,    262,    260,    258,    256,    254,    252,
  250,     248,    246,    245,    243,    241,    239,    237,    236,    234, 
  232,     231,    229,    228,    226,    224,    223,    221,    220,    218,
  217,     216,    214,    213,    211,    210,    209,    207,    206,    205,
  204,     202,    201,    200,    199,    197,    196,    195,    194,    193,
  192,     191,    189,    188,    187,    186,    185,    184,    183,    182, 
  181,     180,    179,    178,    177,    176,    175,    174,    173,    172,
  172,     171,    170,    169,    168,    167,    166,    165,    165,    164,
  163,     162,    161,    161,    160,    159,    158,    158,    157,    156,
  155,     155,    154,    153,    152,    152,    151,    150,    150,    149,
  148,     148,    147,    146,    146,    145,    144,    144,    143,    142,
  142,     141,    141,    140,    139,    139,    138,    138,    137,    137,
  136,     135,    135,    134,    134,    133,    133,    132,    132,    131,
  131,     130,    130,    129,    128,    128,    127,    127,    127,    126,
  126,     125,    125,    124,    124,    123,    123,    122,    122,    121,
  121,     120,    120,    120,    119,    119,    118,    118,    117,    117,
  117,     116,    116,    115,    115,    115,    114,    114,    113,    113,
  113,     112,    112,    111,    111,    111,    110,    110,    110,    109,
  109,     109,    108,    108,    107,    107,    107,    106,    106,    106,
  105,     105,    105,    104,    104,    104,    103,    103,    103,    102,
  102,     102,    101,    101,    101,    101,    100,    100,    100,     99,
   99,      99,     98,     98,     98,     98,     97,     97,     97,     96,
   96,      96,     96,     95,     95,     95,     94,     94,     94,     94,
   93,      93,     93,     93,     92,     92,     92,     92,     91,     91,
   91,      91,     90,     90,     90,     90,     89,     89,     89,     89,
   88,      88,     88,     88,     87,     87,     87,     87,     86,     86,
   86,      86,     86,     85,     85,     85,     85,     84,     84,     84,
   84,      84,     83,     83,     83,     83,     83,     82,     82,     82,
   82,      82,     81,     81,     81,     81,     81,     80,     80,     80,
   80,      80,     79,     79,     79,     79,     79,     78,     78,     78,
   78,      78,     77,     77,     77,     77,     77,     77,     76,     76,
   76,      76,     76,     76,     75,     75,     75,     75,     75,     74,
   74,      74,     74,     74,     74,     73,     73,     73,     73,     73,
   73,      72,     72,     72,     72,     72,     72,     72,     71,     71,
   71,      71,     71,     71,     70,     70,     70,     70,     70,     70,
   70,      69,     69,     69,     69,     69,     69,     69,     68,     68,
   68,      68,     68,     68,     68,     67,     67,     67,     67,     67,
   67,      67,     66,     66,     66,     66,     66,     66,     66,     66,
   65,      65,     65,     65,     65,     65,     65,     65,     64,     64,
   64,       0 ## the last value intentionally "zero"-ized 
]


# Reciprocal ("expln") lookup table with 320 int16 entries.
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# scale_eln = 32768; the values appear to follow round(32767 / (k + 1)).
# The final entry is forced to 0 (see the trailing comment).
# This table is referenced only from a commented-out line in softmax_LUT_rexp.
expln_LUT_1x320_int16 =[                      
32767,   16384,  10922,   8192,   6553,   5461,   4681,   4096,   3641,   3277,
 2979,    2731,   2521,   2341,   2184,   2048,   1927,   1820,   1725,   1638,
 1560,    1489,   1425,   1365,   1311,   1260,   1214,   1170,   1130,   1092,
 1057,    1024,    993,    964,    936,    910,    886,    862,    840,    819,
  799,     780,    762,    745,    728,    712,    697,    683,    669,    655,
  642,     630,    618,    607,    596,    585,    575,    565,    555,    546, 
  537,     529,    520,    512,    504,    496,    489,    482,    475,    468,
  462,     455,    449,    443,    437,    431,    426,    420,    415,    410, 
  405,     400,    395,    390,    385,    381,    377,    372,    368,    364,
  360,     356,    352,    349,    345,    341,    338,    334,    331,    328,
  324,     321,    318,    315,    312,    309,    306,    303,    301,    298,
  295,     293,    290,    287,    285,    282,    280,    278,    275,    273,
  271,     269,    266,    264,    262,    260,    258,    256,    254,    252,
  250,     248,    246,    245,    243,    241,    239,    237,    236,    234, 
  232,     231,    229,    228,    226,    224,    223,    221,    220,    218,
  217,     216,    214,    213,    211,    210,    209,    207,    206,    205,
  204,     202,    201,    200,    199,    197,    196,    195,    194,    193,
  192,     191,    189,    188,    187,    186,    185,    184,    183,    182, 
  181,     180,    179,    178,    177,    176,    175,    174,    173,    172,
  172,     171,    170,    169,    168,    167,    166,    165,    165,    164,
  163,     162,    161,    161,    160,    159,    158,    158,    157,    156,
  155,     155,    154,    153,    152,    152,    151,    150,    150,    149,
  148,     148,    147,    146,    146,    145,    144,    144,    143,    142,
  142,     141,    141,    140,    139,    139,    138,    138,    137,    137,
  136,     135,    135,    134,    134,    133,    133,    132,    132,    131,
  131,     130,    130,    129,    128,    128,    127,    127,    127,    126,
  126,     125,    125,    124,    124,    123,    123,    122,    122,    121,
  121,     120,    120,    120,    119,    119,    118,    118,    117,    117,
  117,     116,    116,    115,    115,    115,    114,    114,    113,    113,
  113,     112,    112,    111,    111,    111,    110,    110,    110,    109,
  109,     109,    108,    108,    107,    107,    107,    106,    106,    106,
  105,     105,    105,    104,    104,    104,    103,    103,    103,     0 ## the last value intentionally "zero"-ized 
]


# Reciprocal ("expln") lookup table with 256 int16 entries.
# This is the table that softmax_LUT_rexp currently selects. It is indexed with
# round(sum(e_x)) - 1 (clamped to the last entry) and the looked-up value is
# divided by scale_eln = 32768, so entry k acts as a fixed-point approximation
# of 1 / (k + 1). The values appear to follow round(32767 / (k + 1)).
# The final entry is forced to 0 (see the trailing comment), so any sum that
# rounds to 256 or more produces an all-zero output row.
expln_LUT_1x256_int16 =[                      
32767,   16384,  10922,   8192,   6553,   5461,   4681,   4096,   3641,   3277,
 2979,    2731,   2521,   2341,   2184,   2048,   1927,   1820,   1725,   1638,
 1560,    1489,   1425,   1365,   1311,   1260,   1214,   1170,   1130,   1092,
 1057,    1024,    993,    964,    936,    910,    886,    862,    840,    819,
  799,     780,    762,    745,    728,    712,    697,    683,    669,    655,
  642,     630,    618,    607,    596,    585,    575,    565,    555,    546, 
  537,     529,    520,    512,    504,    496,    489,    482,    475,    468,
  462,     455,    449,    443,    437,    431,    426,    420,    415,    410, 
  405,     400,    395,    390,    385,    381,    377,    372,    368,    364,
  360,     356,    352,    349,    345,    341,    338,    334,    331,    328,
  324,     321,    318,    315,    312,    309,    306,    303,    301,    298,
  295,     293,    290,    287,    285,    282,    280,    278,    275,    273,
  271,     269,    266,    264,    262,    260,    258,    256,    254,    252,
  250,     248,    246,    245,    243,    241,    239,    237,    236,    234, 
  232,     231,    229,    228,    226,    224,    223,    221,    220,    218,
  217,     216,    214,    213,    211,    210,    209,    207,    206,    205,
  204,     202,    201,    200,    199,    197,    196,    195,    194,    193,
  192,     191,    189,    188,    187,    186,    185,    184,    183,    182, 
  181,     180,    179,    178,    177,    176,    175,    174,    173,    172,
  172,     171,    170,    169,    168,    167,    166,    165,    165,    164,
  163,     162,    161,    161,    160,    159,    158,    158,    157,    156,
  155,     155,    154,    153,    152,    152,    151,    150,    150,    149,
  148,     148,    147,    146,    146,    145,    144,    144,    143,    142,
  142,     141,    141,    140,    139,    139,    138,    138,    137,    137,
  136,     135,    135,    134,    134,    133,    133,    132,    132,    131,
  131,     130,    130,    129,    128,     0 ## the last value intentionally "zero"-ized 
]

# Reciprocal ("expln") lookup table with 32 int16 entries.
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# 32768; the final entry is forced to 0 (see the trailing comment).
# NOTE(review): no line in softmax_LUT_rexp, active or commented out,
# references this table.
expln_LUT_1x32_int16 =[                      
32767,   16384,  10922,   8192,   6553,   5461,   4681,   4096,   3641,   3277,
 2979,    2731,   2521,   2341,   2184,   2048,   1927,   1820,   1725,   1638, 
 1560,    1489,   1425,   1365,   1311,   1260,   1214,   1170,   1130,   1092,
 1057,       0 ## the last value intentionally "zero"-ized 
]

# Reciprocal ("expln") lookup table with 16 int16 entries (paired with
# scale_eln = 32768 in the commented-out config of softmax_LUT_rexp).
# The final entry is forced to 0 (see the trailing comment).
expln_LUT_1x16_int16 =[                      
32767,   16384,  10922,   8192,   6553,   5461,   4681,   4096,   3641,   3277,
 2979,    2731,   2521,   2341,   2184,      0 ## the last value intentionally "zero"-ized 
]




# Reciprocal ("expln") lookup table with 8-bit codes (paired with
# scale_eln = 256 in the commented-out config of softmax_LUT_rexp).
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# 256; the values appear to follow round(255 / (k + 1)).
# NOTE(review): by count this literal appears to hold 514 entries (46 rows of
# 11 plus a final row of 8) and ends with two zeros, although the name says
# 1x512 and the trailing comment mentions a single zeroed value -- confirm.
expln_LUT_1x512_uint8 =[                      
255,   128,   85,   64,   51,   43,   36,   32,   28,   26,   23,
 21,    20,   18,   17,   16,   15,   14,   13,   13,   12,   12, 
 11,    11,   10,   10,    9,    9,    9,    9,    8,    8,    8,
  8,     7,    7,    7,    7,    7,    6,    6,    6,    6,    6,
  6,     6,    5,    5,    5,    5,    5,    5,    5,    5,    5,
  5,     4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
  4,     4,    4,    4,    4,    4,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    0,    0 ## the last value intentionally "zero"-ized 
]

# Reciprocal ("expln") lookup table with 320 entries of 8-bit codes (paired with
# scale_eln = 256 in the commented-out config of softmax_LUT_rexp).
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# 256; the values appear to follow round(255 / (k + 1)).
# The final entry is forced to 0 (see the trailing comment).
expln_LUT_1x320_uint8 =[                      
255,   128,   85,   64,   51,   43,   36,   32,   28,   26,   23,
 21,    20,   18,   17,   16,   15,   14,   13,   13,   12,   12, 
 11,    11,   10,   10,    9,    9,    9,    9,    8,    8,    8,
  8,     7,    7,    7,    7,    7,    6,    6,    6,    6,    6,
  6,     6,    5,    5,    5,    5,    5,    5,    5,    5,    5,
  5,     4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
  4,     4,    4,    4,    4,    4,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  0 ## the last value intentionally "zero"-ized 
]

# Reciprocal ("expln") lookup table with 256 entries of 8-bit codes (paired with
# scale_eln = 256 in the commented-out config of softmax_LUT_rexp).
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# 256; the values appear to follow round(255 / (k + 1)).
# The final entry is forced to 0 (see the trailing comment).
expln_LUT_1x256_uint8 =[                      
255,   128,   85,   64,   51,   43,   36,   32,   28,   26,   23,
 21,    20,   18,   17,   16,   15,   14,   13,   13,   12,   12, 
 11,    11,   10,   10,    9,    9,    9,    9,    8,    8,    8,
  8,     7,    7,    7,    7,    7,    6,    6,    6,    6,    6,
  6,     6,    5,    5,    5,    5,    5,    5,    5,    5,    5,
  5,     4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
  4,     4,    4,    4,    4,    4,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
  3,     3,    3,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
  2,     2,    2,    2,    2,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
  1,     1,    0 ## the last value intentionally "zero"-ized 
]


# Reciprocal ("expln") lookup table with 32 entries of 8-bit codes.
# Entry k acts as a fixed-point approximation of 1 / (k + 1) once divided by
# 256; the final entry is forced to 0 (see the trailing comment).
# NOTE(review): no line in softmax_LUT_rexp, active or commented out,
# references this table.
expln_LUT_1x32_uint8 =[                      
255,   128,   85,   64,   51,   43,   36,   32,   28,   26,   23,
 21,    20,   18,   17,   16,   15,   14,   13,   13,   12,   12, 
 11,    11,   10,   10,    9,    9,    9,    9,    8,    0 ## the last value intentionally "zero"-ized 
]

# 16-entry reciprocal table with 8-bit codes (paired with scale_eln = 256 in
# the commented-out config of softmax_LUT_rexp); final entry forced to 0.
expln_LUT_1x16_uint8 =[                      
255,   128,   85,   64,   51,   43,   36,   32,   28,   26,   23,
 21,    20,   18,   17,    0 ## the last value intentionally "zero"-ized 
]

# 16-entry reciprocal table with 4-bit codes (paired with scale_eln = 16 in
# the commented-out config of softmax_LUT_rexp); final entry forced to 0.
expln_LUT_1x16_uint4 =[                      
15,   8,   5,   4,   3,   3,   2,   2,   2,   2,
 1,   1,   1,   1,   1,   0 ## the last value intentionally "zero"-ized 
]

# 7-entry reciprocal table with 2-bit codes (paired with scale_eln = 4 in the
# commented-out config of softmax_LUT_rexp); the final entry is 0.
expln_LUT_1x7_uint2 =[                      
3,   2,   1,   1,   1,   1,   0
]



def softmax_LUT_rexp(x, axis_max, rexp_lut=None, expln_lut=None,
                     scale=32768, scale_eln=32768):
    """
    REXP LUT-based SoftMAX computation, to approximate and quantize the activations.

    The result is ``rexp_lut[round(max - x)] / scale`` multiplied by
    ``expln_lut[round(sum) - 1] / scale_eln``, where ``max`` and ``sum`` are
    both taken along ``axis_max`` and ``sum`` is the sum of the first factor.
    Both table indices are clamped to the valid range of their table.

    Args:
        x: input tensor of any rank. It is detached from the autograd graph,
            so no gradient flows through this function.
        axis_max: dimension along which the softmax is computed (the maximum
            and the normalising sum are both reduced along it).
        rexp_lut: sequence (or tensor) of integer codes for the exponent
            table. Defaults to ``rexp_LUT_1x13_int16``. The other tables in
            this module pair with these scales: ``rexp_LUT_1x8_uint8`` with
            ``scale=256``, ``rexp_LUT_1x5_uint4`` with ``scale=16``,
            ``rexp_LUT_1x3_uint2`` with ``scale=4``.
        expln_lut: sequence (or tensor) of integer codes for the reciprocal
            table. Defaults to ``expln_LUT_1x256_int16``. Pairings: the
            ``*_int16`` tables with ``scale_eln=32768``, the ``*_uint8``
            tables with ``scale_eln=256``, ``expln_LUT_1x16_uint4`` with
            ``scale_eln=16``, ``expln_LUT_1x7_uint2`` with ``scale_eln=4``.
        scale: divisor (2^w) that turns ``rexp_lut`` codes into fractions.
        scale_eln: divisor (2^w) that turns ``expln_lut`` codes into fractions.

    Returns:
        Tensor with the same shape as ``x`` holding the approximated softmax.
    """
    # Defaults are resolved at call time so that the original two-argument
    # call keeps using the same tables as before.
    if rexp_lut is None:
        rexp_lut = rexp_LUT_1x13_int16
    if expln_lut is None:
        expln_lut = expln_LUT_1x256_int16

    # The original copy-constructed x with torch.tensor(x), which detaches it
    # (and emits a UserWarning). detach() keeps that behaviour; no copy is
    # needed because x is never modified in place.
    x = x.detach()

    # Build the tables on the input's device so CUDA inputs can index them.
    rexp_table = torch.as_tensor(rexp_lut, dtype=torch.int16, device=x.device)
    expln_table = torch.as_tensor(expln_lut, dtype=torch.int16, device=x.device)

    ## 1st loop --> compute e_x
    # Distance to the maximum is >= 0; round it to the nearest table index and
    # saturate at the last entry.
    x_max = torch.max(x, dim=axis_max, keepdim=True)[0]
    idx_e_x = torch.round(x_max - x)
    idx_e_x = torch.clamp(idx_e_x, min=0, max=rexp_table.shape[0] - 1).long()
    e_x = rexp_table[idx_e_x].to(x.dtype) / scale

    ## accumulate Sum(e^x) along the same axis the maximum was taken on
    sum_e_x = torch.sum(e_x, dim=axis_max, keepdim=True)

    ## 2nd loop --> compute exp(ln(x)), i.e. look up 1 / Sum(e^x)
    # Entry k of the table stands for 1 / (k + 1), hence the "- 1".
    idx_sum_e_x = torch.round(sum_e_x) - 1
    idx_sum_e_x = torch.clamp(idx_sum_e_x, min=0, max=expln_table.shape[0] - 1).long()
    # Direct indexing replaces tiling the whole table once per row.
    inv_sum = expln_table[idx_sum_e_x].to(x.dtype) / scale_eln

    ## 3rd loop (final) --> compute softmax
    # inv_sum keeps a size-1 dimension at axis_max, so broadcasting handles
    # inputs of any rank (the previous repeat(1, 1, N) only fit 3-D inputs).
    return e_x * inv_sum

