#################################################################################
#                                                                               # 
#                                WARNING !!!                                    # 
#                                                                               # 
#       Source-code for NeurIPS 2021 submission to reproduce the results.       #    
#     Once the paper is accepted, the code will be made publicly available.     #
#     Until then, we strongly request that you do not distribute the code.      #
#       Thank you!^^                                                            #
#                                                                               # 
#################################################################################
 


"""LUT-based quantization for softmax in Transformer model. """

import torch


# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 32767 at index 100.
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100, so
# index 100 is hit for x == max and index 0 for x - max <= -5.
# Presumably meant to be passed as LUT_e with scale = 32767 -- confirm with caller.
exp_LUT_1x101_int16 =[  0,    232,    244,    257,    270,    283,    298,    313,    329,    346,     
                      364,    383,    402,    423,    445,    467,    491,    517,    543,    571,     
                      600,    631,    663,    697,    733,    771,    810,    852,    895,    941,     
                      990,    1040,   1094,   1150,   1209,   1271,   1336,   1404,   1476,   1552,     
                      1631,   1715,   1803,   1895,   1993,   2095,   2202,   2315,   2434,   2559,     
                      2690,   2828,   2973,   3125,   3285,   3454,   3631,   3817,   4013,   4218,     
                      4435,   4662,   4901,   5152,   5417,   5694,   5986,   6293,   6616,   6955,     
                      7312,   7686,   8080,   8495,   8930,   9388,   9870,   10376,  10908,  11467,     
                      12055,  12673,  13322,  14006,  14724,  15479,  16272,  17106,  17983,  18905,     
                      19875,  20894,  21965,  23091,  24275,  25520,  26828,  28204,  29650,  31170,  32767]   
                      
# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 250 at index 100.
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 250 -- confirm with caller.
exp_LUT_1x101_uint8 =[0,	  2,	  2,    2,    2,    2,    2,    2,    3,    3,	
                      3,    3,    3,    3,    3,    4,    4,    4,    4,    4,	
                      5,    5,    5,    5,    6,    6,    6,    6,    7,    7,	
                      8,    8,    8,    9,    9,    10,   10,   11,   11,   12,	
                      12,   13,	  14,   14,	  15,   16,   17,   18,	  19,   20,	
                      21,   22,   23,   24,   25,	  26,   28,   29,   31,   32,	
                      34,   36,   37,   39,   41,   43,   46,   48,   50,	  53,	
                      56,   59,	  62,   65,   68,   72,   75,   79,   83,   87,	
                      92,   97,   102,  107,  112,  118,  124,  131,	137,  144,	
                      152,  159,  168,  176,  185,  195,  205,  215,  226,  238,  250]	

# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 125 at index 100.
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 125 -- confirm with caller.
exp_LUT_1x101_int8 =[0,    1,    1,    1,    1,    1,    1,    1,    1,    1,               
                     1,    1,    2,    2,    2,    2,    2,    2,    2,    2,               
                     2,    2,    3,    3,    3,    3,    3,    3,    3,    4,               
                     4,    4,    4,    4,    5,    5,    5,    5,    6,    6,               
                     6,    7,    7,    7,    8,    8,    8,    9,    9,    10,               
                     10,   11,   11,   12,   13,   13,   14,   15,   15,   16,               
                     17,   18,   19,   20,   21,   22,   23,   24,   25,   27,               
                     28,   29,   31,   32,   34,   36,   38,   40,   42,   44,               
                     46,   48,   51,   53,   56,   59,   62,   65,   69,   72,               
                     76,   80,   84,   88,   93,   97,   102,  108,  113,  119,  125]                        

# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 63 at index 100.
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 63 -- confirm with caller.
exp_LUT_1x101_uint6 =[                      
0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,
3,   3,   4,   4,   4,   4,   4,   5,   5,   5,   5,   6,   6,   6,   6,   7,   7,   7,   8,   8,
9,   9,   10,   10,   11,   11,   12,   12,   13,   14,   14,   15,   16,   17,   17,   18,   19,   20,   21,   22,
24,   25,   26,   27,   29,   30,   32,   33,   35,   37,   39,   41,   43,   45,   47,   50,   52,   55,   58,   61,
63   
]

# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 31 at index 100.
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 31 -- confirm with caller.
exp_LUT_1x101_uint5 =[                      
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,
1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   2,
2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,
4,   5,   5,   5,   5,   6,   6,   6,   6,   7,   7,   8,   8,   8,   9,   9,   10,   10,   11,   11,
12,   12,   13,   14,   14,   15,   16,   17,   18,   18,   19,   20,   21,   23,   24,   25,   26,   28,   29,   30,
31   
]
# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 15 at index 100 (the last two entries are both 15).
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 15 -- confirm with caller.
exp_LUT_1x101_uint4 =[                      
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,
1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,
2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,   4,   5,   5,   5,   5,   6,
6,   6,   7,   7,   7,   8,   8,   8,   9,   9,   10,   10,   11,   11,   12,   12,   13,   14,   14,   15,
15
]

# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 7 at index 100 (the last five entries are all 7).
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 7 -- confirm with caller.
exp_LUT_1x101_int4 =[                      
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,
3,   3,   3,   3,   4,   4,   4,   4,   4,   5,   5,   5,   5,   6,   6,   6,   7,   7,   7,   7,
7
]    

# Exponent look-up table: 101 integer entries, 0 at index 0 and the peak
# value 3 at index 100 (the last ten entries are all 3).
# softmax_LUT() reads its LUT_e argument at round(20 * (x - max)) + 100.
# Presumably meant to be passed as LUT_e with scale = 3 -- confirm with caller.
exp_LUT_1x101_uint2 =[                      
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
1,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   3,
3
]

                          
# Division look-up table: 60 rows x 11 columns, peak value 32767 at [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 and
# column round(10 * e_x), so row r serves a row sum of about r + 1 and
# column c serves e_x of about c / 10.
# Presumably meant to be passed as LUT_s with scale = 32767 -- confirm with caller.
Softmax_LUT_11x60_int16 =[[0,    3277,    6554,    9830,    13107,    16384,    19661,    22938,    26214,    29491,    32767],
                          [0,    1638,    3277,    4915,    6554,    8192,    9830,    11469,    13107,    14746,    16384],
                          [0,    1092,    2185,    3277,    4369,    5461,    6554,    7646,    8738,    9830,    10923],
                          [0,    819,    1638,    2458,    3277,    4096,    4915,    5734,    6554,    7373,    8192],
                          [0,    655,    1311,    1966,    2621,    3277,    3932,    4588,    5243,    5898,    6554],
                          [0,    546,    1092,    1638,    2185,    2731,    3277,    3823,    4369,    4915,    5461],
                          [0,    468,    936,    1404,    1872,    2341,    2809,    3277,    3745,    4213,    4681],
                          [0,    410,    819,    1229,    1638,    2048,    2458,    2867,    3277,    3686,    4096],
                          [0,    364,    728,    1092,    1456,    1820,    2185,    2549,    2913,    3277,    3641],
                          [0,    328,    655,    983,    1311,    1638,    1966,    2294,    2621,    2949,    3277],
                          [0,    298,    596,    894,    1192,    1489,    1787,    2085,    2383,    2681,    2979],
                          [0,    273,    546,    819,    1092,    1365,    1638,    1911,    2185,    2458,    2731],
                          [0,    252,    504,    756,    1008,    1260,    1512,    1764,    2016,    2269,    2521],
                          [0,    234,    468,    702,    936,    1170,    1404,    1638,    1872,    2107,    2341],
                          [0,    218,    437,    655,    874,    1092,    1311,    1529,    1748,    1966,    2185],
                          [0,    205,    410,    614,    819,    1024,    1229,    1434,    1638,    1843,    2048],
                          [0,    193,    386,    578,    771,    964,    1157,    1349,    1542,    1735,    1928],
                          [0,    182,    364,    546,    728,    910,    1092,    1274,    1456,    1638,    1820],
                          [0,    172,    345,    517,    690,    862,    1035,    1207,    1380,    1552,    1725],
                          [0,    164,    328,    492,    655,    819,    983,    1147,    1311,    1475,    1638],
                          [0,    156,    312,    468,    624,    780,    936,    1092,    1248,    1404,    1560],
                          [0,    149,    298,    447,    596,    745,    894,    1043,    1192,    1341,    1489],
                          [0,    142,    285,    427,    570,    712,    855,    997,    1140,    1282,    1425],
                          [0,    137,    273,    410,    546,    683,    819,    956,    1092,    1229,    1365],
                          [0,    131,    262,    393,    524,    655,    786,    918,    1049,    1180,    1311],
                          [0,    126,    252,    378,    504,    630,    756,    882,    1008,    1134,    1260],
                          [0,    121,    243,    364,    485,    607,    728,    850,    971,    1092,    1214],
                          [0,    117,    234,    351,    468,    585,    702,    819,    936,    1053,    1170],
                          [0,    113,    226,    339,    452,    565,    678,    791,    904,    1017,    1130],
                          [0,    109,    218,    328,    437,    546,    655,    765,    874,    983,    1092],
                          [0,    106,    211,    317,    423,    529,    634,    740,    846,    951,    1057],
                          [0,    102,    205,    307,    410,    512,    614,    717,    819,    922,    1024],
                          [0,    99,    199,    298,    397,    496,    596,    695,    794,    894,    993],
                          [0,    96,    193,    289,    386,    482,    578,    675,    771,    867,    964],
                          [0,    94,    187,    281,    374,    468,    562,    655,    749,    843,    936],
                          [0,    91,    182,    273,    364,    455,    546,    637,    728,    819,    910],
                          [0,    89,    177,    266,    354,    443,    531,    620,    708,    797,    886],
                          [0,    86,    172,    259,    345,    431,    517,    604,    690,    776,    862],
                          [0,    84,    168,    252,    336,    420,    504,    588,    672,    756,    840],
                          [0,    82,    164,    246,    328,    410,    492,    573,    655,    737,    819],
                          [0,    80,    160,    240,    320,    400,    480,    559,    639,    719,    799],
                          [0,    78,    156,    234,    312,    390,    468,    546,    624,    702,    780],
                          [0,    76,    152,    229,    305,    381,    457,    533,    610,    686,    762],
                          [0,    74,    149,    223,    298,    372,    447,    521,    596,    670,    745],
                          [0,    73,    146,    218,    291,    364,    437,    510,    583,    655,    728],
                          [0,    71,    142,    214,    285,    356,    427,    499,    570,    641,    712],
                          [0,    70,    139,    209,    279,    349,    418,    488,    558,    627,    697],
                          [0,    68,    137,    205,    273,    341,    410,    478,    546,    614,    683],
                          [0,    67,    134,    201,    267,    334,    401,    468,    535,    602,    669],
                          [0,    66,    131,    197,    262,    328,    393,    459,    524,    590,    655],
                          [0,    64,    129,    193,    257,    321,    386,    450,    514,    578,    643],
                          [0,    63,    126,    189,    252,    315,    378,    441,    504,    567,    630],
                          [0,    62,    124,    185,    247,    309,    371,    433,    495,    556,    618],
                          [0,    61,    121,    182,    243,    303,    364,    425,    485,    546,    607],
                          [0,    60,    119,    179,    238,    298,    357,    417,    477,    536,    596],
                          [0,    59,    117,    176,    234,    293,    351,    410,    468,    527,    585],
                          [0,    57,    115,    172,    230,    287,    345,    402,    460,    517,    575],
                          [0,    56,    113,    169,    226,    282,    339,    395,    452,    508,    565],
                          [0,    56,    111,    167,    222,    278,    333,    389,    444,    500,    555],
                          [0,    55,    109,    164,    218,    273,    328,    382,    437,    492,    546]]
        
# Division look-up table: 60 rows x 11 columns, peak value 250 at [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 and
# column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 250 -- confirm with caller.
Softmax_LUT_11x60_uint8 = [[0,	25,	50,	75,	100,	125,	150,	175,	200,	225,	250],
                            [0,	13,	25,	38,	50,	63,	75,	88,	100,	113,	125],
                            [0,	8,	17,	25,	33,	42,	50,	58,	67,	75,	83],
                            [0,	6,	13,	19,	25,	31,	38,	44,	50,	56,	63],
                            [0,	5,	10,	15,	20,	25,	30,	35,	40,	45,	50],
                            [0,	4,	8,	13,	17,	21,	25,	29,	33,	38,	42],
                            [0,	4,	7,	11,	14,	18,	21,	25,	29,	32,	36],
                            [0,	3,	6,	9,	13,	16,	19,	22,	25,	28,	31],
                            [0,	3,	6,	8,	11,	14,	17,	19,	22,	25,	28],
                            [0,	3,	5,	8,	10,	13,	15,	18,	20,	23,	25],
                            [0,	2,	5,	7,	9,	11,	14,	16,	18,	20,	23],
                            [0,	2,	4,	6,	8,	10,	13,	15,	17,	19,	21],
                            [0,	2,	4,	6,	8,	10,	12,	13,	15,	17,	19],
                            [0,	2,	4,	5,	7,	9,	11,	13,	14,	16,	18],
                            [0,	2,	3,	5,	7,	8,	10,	12,	13,	15,	17],
                            [0,	2,	3,	5,	6,	8,	9,	11,	13,	14,	16],
                            [0,	1,	3,	4,	6,	7,	9,	10,	12,	13,	15],
                            [0,	1,	3,	4,	6,	7,	8,	10,	11,	13,	14],
                            [0,	1,	3,	4,	5,	7,	8,	9,	11,	12,	13],
                            [0,	1,	3,	4,	5,	6,	8,	9,	10,	11,	13],
                            [0,	1,	2,	4,	5,	6,	7,	8,	10,	11,	12],
                            [0,	1,	2,	3,	5,	6,	7,	8,	9,	10,	11],
                            [0,	1,	2,	3,	4,	5,	7,	8,	9,	10,	11],
                            [0,	1,	2,	3,	4,	5,	6,	7,	8,	9,	10],
                            [0,	1,	2,	3,	4,	5,	6,	7,	8,	9,	10],
                            [0,	1,	2,	3,	4,	5,	6,	7,	8,	9,	10],
                            [0,	1,	2,	3,	4,	5,	6,	6,	7,	8,	9],
                            [0,	1,	2,	3,	4,	4,	5,	6,	7,	8,	9],
                            [0,	1,	2,	3,	3,	4,	5,	6,	7,	8,	9],
                            [0,	1,	2,	3,	3,	4,	5,	6,	7,	8,	8],
                            [0,	1,	2,	2,	3,	4,	5,	6,	6,	7,	8],
                            [0,	1,	2,	2,	3,	4,	5,	5,	6,	7,	8],
                            [0,	1,	2,	2,	3,	4,	5,	5,	6,	7,	8],
                            [0,	1,	1,	2,	3,	4,	4,	5,	6,	7,	7],
                            [0,	1,	1,	2,	3,	4,	4,	5,	6,	6,	7],
                            [0,	1,	1,	2,	3,	3,	4,	5,	6,	6,	7],
                            [0,	1,	1,	2,	3,	3,	4,	5,	5,	6,	7],
                            [0,	1,	1,	2,	3,	3,	4,	5,	5,	6,	7],
                            [0,	1,	1,	2,	3,	3,	4,	4,	5,	6,	6],
                            [0,	1,	1,	2,	3,	3,	4,	4,	5,	6,	6],
                            [0,	1,	1,	2,	2,	3,	4,	4,	5,	5,	6],
                            [0,	1,	1,	2,	2,	3,	4,	4,	5,	5,	6],
                            [0,	1,	1,	2,	2,	3,	3,	4,	5,	5,	6],
                            [0,	1,	1,	2,	2,	3,	3,	4,	5,	5,	6],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	6],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	5],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	5],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	5],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	5],
                            [0,	1,	1,	2,	2,	3,	3,	4,	4,	5,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	5],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	4],
                            [0,	0,	1,	1,	2,	2,	3,	3,	4,	4,	4],
                            [0,	0,	1,	1,	2,	2,	3,	3,	3,	4,	4],
                            [0,	0,	1,	1,	2,	2,	3,	3,	3,	4,	4],
                            [0,	0,	1,	1,	2,	2,	3,	3,	3,	4,	4]]

# Division look-up table: 60 rows x 11 columns, peak value 125 at [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 and
# column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 125 -- confirm with caller.
Softmax_LUT_11x60_int8 = [[0,    13,    25,    38,    50,    63,    75,    88,    100,    113,    125],
                          [0,    6,    13,    19,    25,    31,    38,    44,    50,    56,    63],
                          [0,    4,    8,    13,    17,    21,    25,    29,    33,    38,    42],
                          [0,    3,    6,    9,    13,    16,    19,    22,    25,    28,    31],
                          [0,    3,    5,    8,    10,    13,    15,    18,    20,    23,    25],
                          [0,    2,    4,    6,    8,    10,    13,    15,    17,    19,    21],
                          [0,    2,    4,    5,    7,    9,    11,    13,    14,    16,    18],
                          [0,    2,    3,    5,    6,    8,    9,    11,    13,    14,    16],
                          [0,    1,    3,    4,    6,    7,    8,    10,    11,    13,    14],
                          [0,    1,    3,    4,    5,    6,    8,    9,    10,    11,    13],
                          [0,    1,    2,    3,    5,    6,    7,    8,    9,    10,    11],
                          [0,    1,    2,    3,    4,    5,    6,    7,    8,    9,    10],
                          [0,    1,    2,    3,    4,    5,    6,    7,    8,    9,    10],
                          [0,    1,    2,    3,    4,    4,    5,    6,    7,    8,    9],
                          [0,    1,    2,    3,    3,    4,    5,    6,    7,    8,    8],
                          [0,    1,    2,    2,    3,    4,    5,    5,    6,    7,    8],
                          [0,    1,    1,    2,    3,    4,    4,    5,    6,    7,    7],
                          [0,    1,    1,    2,    3,    3,    4,    5,    6,    6,    7],
                          [0,    1,    1,    2,    3,    3,    4,    5,    5,    6,    7],
                          [0,    1,    1,    2,    3,    3,    4,    4,    5,    6,    6],
                          [0,    1,    1,    2,    2,    3,    4,    4,    5,    5,    6],
                          [0,    1,    1,    2,    2,    3,    3,    4,    5,    5,    6],
                          [0,    1,    1,    2,    2,    3,    3,    4,    4,    5,    5],
                          [0,    1,    1,    2,    2,    3,    3,    4,    4,    5,    5],
                          [0,    1,    1,    2,    2,    3,    3,    4,    4,    5,    5],
                          [0,    0,    1,    1,    2,    2,    3,    3,    4,    4,    5],
                          [0,    0,    1,    1,    2,    2,    3,    3,    4,    4,    5],
                          [0,    0,    1,    1,    2,    2,    3,    3,    4,    4,    4],
                          [0,    0,    1,    1,    2,    2,    3,    3,    3,    4,    4],
                          [0,    0,    1,    1,    2,    2,    3,    3,    3,    4,    4],
                          [0,    0,    1,    1,    2,    2,    2,    3,    3,    4,    4],
                          [0,    0,    1,    1,    2,    2,    2,    3,    3,    4,    4],
                          [0,    0,    1,    1,    2,    2,    2,    3,    3,    3,    4],
                          [0,    0,    1,    1,    1,    2,    2,    3,    3,    3,    4],
                          [0,    0,    1,    1,    1,    2,    2,    3,    3,    3,    4],
                          [0,    0,    1,    1,    1,    2,    2,    2,    3,    3,    3],
                          [0,    0,    1,    1,    1,    2,    2,    2,    3,    3,    3],
                          [0,    0,    1,    1,    1,    2,    2,    2,    3,    3,    3],
                          [0,    0,    1,    1,    1,    2,    2,    2,    3,    3,    3],
                          [0,    0,    1,    1,    1,    2,    2,    2,    3,    3,    3],
                          [0,    0,    1,    1,    1,    2,    2,    2,    2,    3,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    3,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    3,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    3,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    3,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    2,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    2,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    2,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    2,    3],
                          [0,    0,    1,    1,    1,    1,    2,    2,    2,    2,    3],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    2,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    1,    2,    2,    2],
                          [0,    0,    0,    1,    1,    1,    1,    1,    2,    2,    2]]

# Division look-up table: 60 rows x 11 columns, peak value 63 at [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 and
# column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 63 -- confirm with caller.
Softmax_LUT_11x60_uint6 = [
[0,   6,   13,   19,   26,   32,   38,   45,   51,   58,   63],
[0,   3,   6,   10,   13,   16,   19,   22,   26,   29,   32],
[0,   2,   4,   6,   9,   11,   13,   15,   17,   19,   21],
[0,   2,   3,   5,   6,   8,   10,   11,   13,   14,   16],
[0,   1,   3,   4,   5,   6,   8,   9,   10,   12,   13],
[0,   1,   2,   3,   4,   5,   6,   7,   9,   10,   11],
[0,   1,   2,   3,   4,   5,   5,   6,   7,   8,   9],
[0,   1,   2,   2,   3,   4,   5,   6,   6,   7,   8],
[0,   1,   1,   2,   3,   4,   4,   5,   6,   6,   7],
[0,   1,   1,   2,   3,   3,   4,   4,   5,   6,   6],
[0,   1,   1,   2,   2,   3,   3,   4,   5,   5,   6],
[0,   1,   1,   2,   2,   3,   3,   4,   4,   5,   5],
[0,   0,   1,   1,   2,   2,   3,   3,   4,   4,   5],
[0,   0,   1,   1,   2,   2,   3,   3,   4,   4,   5],
[0,   0,   1,   1,   2,   2,   3,   3,   3,   4,   4],
[0,   0,   1,   1,   2,   2,   2,   3,   3,   4,   4],
[0,   0,   1,   1,   2,   2,   2,   3,   3,   3,   4],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   4],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   3],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   3],
[0,   0,   1,   1,   1,   2,   2,   2,   2,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1]
]

# Division look-up table: 58 rows x 11 columns, peak value 31 at [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 (clamped
# to the last row) and column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 31 -- confirm with caller.
Softmax_LUT_11x58_uint5 = [
[0,   3,   6,   10,   13,   16,   19,   22,   26,   29,   31],
[0,   2,   3,   5,   6,   8,   10,   11,   13,   14,   16],
[0,   1,   2,   3,   4,   5,   6,   7,   9,   10,   11],
[0,   1,   2,   2,   3,   4,   5,   6,   6,   7,   8],
[0,   1,   1,   2,   3,   3,   4,   4,   5,   6,   6],
[0,   1,   1,   2,   2,   3,   3,   4,   4,   5,   5],
[0,   0,   1,   1,   2,   2,   3,   3,   4,   4,   5],
[0,   0,   1,   1,   2,   2,   2,   3,   3,   4,   4],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   4],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1]
]

# Division look-up table: 11 columns, peak value 15 at [0][10].
# NOTE(review): despite the "11x16" in the name, this literal contains 29
# rows, not 16. softmax_LUT() derives the row count from the table itself
# (LUT_shape[0]), so the extra rows are used; confirm whether the name or
# the data is the intended one.
# Presumably meant to be passed as LUT_s with scale = 15 -- confirm with caller.
Softmax_LUT_11x16_uint4 = [
[0,   2,   3,   5,   6,   8,   10,   11,   13,   14,   15],
[0,   1,   2,   2,   3,   4,   5,   6,   6,   7,   8],
[0,   1,   1,   2,   2,   3,   3,   4,   4,   5,   5],
[0,   0,   1,   1,   2,   2,   2,   3,   3,   4,   4],
[0,   0,   1,   1,   1,   2,   2,   2,   3,   3,   3],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3],
[0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   1,   1,   1,   1,   1,   1,   2,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1]
]

# Division look-up table: 16 rows x 11 columns, peak value 7 at [0][9] and [0][10].
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 (clamped
# to the last row) and column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 7 -- confirm with caller.
Softmax_LUT_11x16_int4 = [
[0,   1,   2,   2,   3,   4,   5,   6,   6,   7,   7],
[0,   0,   1,   1,   2,   2,   2,   3,   3,   4,   4],
[0,   0,   1,   1,   1,   1,   2,   2,   2,   2,   3],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1]
]


# Division look-up table: 8 rows x 11 columns, peak value 3 (row 0, columns 7-10).
# softmax_LUT() reads its LUT_s argument at row round(sum_e_x) - 1 (clamped
# to the last row) and column round(10 * e_x).
# Presumably meant to be passed as LUT_s with scale = 3 -- confirm with caller.
Softmax_LUT_11x8_uint2 = [
[0,   0,   1,   1,   2,   2,   2,   3,   3,   3,   3],
[0,   0,   0,   1,   1,   1,   1,   1,   2,   2,   2],
[0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1],
[0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1]
]


def softmax_LUT(x, LUT_e, LUT_s, scale):
    """Approximate a quantized softmax over the last dimension using two LUTs.

    The computation is done in two table look-ups:

    1. Exponent stage: ``x - max(x)`` (max taken over the last dimension) is
       mapped to the index ``round(20 * (x - max)) + 100`` into ``LUT_e``;
       the fetched integer divided by ``scale`` is the quantized ``e_x``.
    2. Division stage: ``round(sum(e_x)) - 1`` selects a row of ``LUT_s`` and
       ``round(10 * e_x)`` selects the column; the fetched integer divided
       by ``scale`` is the result.

    Both indices are clamped to the bounds of their table, so a row sum that
    exceeds the number of rows of ``LUT_s`` uses the last row.

    Args:
        x: Floating-point tensor; the softmax is taken along its last
            dimension.
        LUT_e: 1-D sequence (or tensor) of integers in the int16 range. With
            the hard-coded index mapping above, index 100 is read for the
            row maximum, so at least 101 entries are expected.
        LUT_s: 2-D sequence (or tensor) of integers in the int16 range,
            indexed as ``LUT_s[row][column]``.
        scale: Divisor applied to the values fetched from both tables
            (presumably the full-scale value of the tables -- confirm with
            the caller).

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    scale_exp = 20      # LUT_e entries per unit of (x - max)
    bias_exp = 100      # LUT_e index read when x == max
    scale_softmax = 10  # LUT_s columns per unit of e_x

    # Build the tables on the device of `x` so the look-ups also work when
    # `x` does not live on the CPU.
    exp_LUT = torch.as_tensor(LUT_e, dtype=torch.int16, device=x.device)
    softmx_LUT = torch.as_tensor(LUT_s, dtype=torch.int16, device=x.device)
    n_rows, n_cols = softmx_LUT.shape

    ## 1st look-up --> compute e_x
    x_max = torch.max(x, dim=-1, keepdim=True)[0]
    idx_inp_e_x = torch.round((x - x_max) * scale_exp) + bias_exp
    idx_inp_e_x = torch.clamp(idx_inp_e_x, min=0, max=exp_LUT.numel() - 1)
    e_x = exp_LUT[idx_inp_e_x.long()].to(x.dtype) / scale

    # keepdim=True lets the row index broadcast against the column index.
    sum_e_x = torch.sum(e_x, -1, keepdim=True)

    ## 2nd look-up --> compute softmax
    # Clamp both indices to the table bounds: without the upper column bound
    # an out-of-range e_x would silently read from the following row.
    idx_e_x = torch.clamp(torch.round(scale_softmax * e_x),
                          min=0, max=n_cols - 1).long()
    idx_sum_e_x = torch.clamp(torch.round(sum_e_x) - 1,
                              min=0, max=n_rows - 1).long()

    # Direct 2-D indexing replaces tiling the whole table once per element,
    # which needed (number of elements) x (table size) temporary storage.
    y_LUT = softmx_LUT[idx_sum_e_x, idx_e_x].to(x.dtype)

    return y_LUT / scale
