import numpy as np


_list110 = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]
_list110 = np.append(_list110, 64)
_list110 = np.append(_list110, 10)


# rrelu on parallel path (In a block both are rrelu)
_listrrelu = [11, 16, 15, 16, 16, 16, 15, 16, 14, 16, 15, 16, 13, 16, 14, 16, 12, 16, 14, 16, 14, 16, 13, 16, 12, 16, 14, 16, 15, 16, 7, 16, 11, 16, 12, 16, 8, 32, 31, 32, 30, 32, 30, 32, 32, 32, 30, 32, 32, 32, 32, 32, 32, 32, 32, 32, 28, 32, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 29, 32, 31, 32, 32, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]
_listrrelu = np.append(_listrrelu, 64)
_listrrelu = np.append(_listrrelu, 10)
## resnet56 Cifar100 gamma=0.04
for i in range(len(_list110)):
    print(i, _list110[i], _listrrelu[i])

## Network slimming
k = 3
h_out_h_0 = 32
h_out_w_0 = 32
h_out_h_1 = 16
h_out_w_1 = 16
h_out_h_2 = 8
h_out_w_2 = 8

flop110 = 0
memory110 = 0
flop = 0
memory = 0

# For resnet110
for i in range(len(_listrrelu)):  
    if i==0:
        flop110 += 2*_list110[i]*3*3*3*32*32
        memory110 += _list110[i]*3*3*3
        flop += 2*_listrrelu[i]*3*3*3*32*32
        memory += _listrrelu[i]*3*3*3

    if i>0 and i<=36:
        flop110 += 2*_list110[i]*_list110[i-1]*3*3*32*32
        memory110 += _list110[i]*_list110[i-1]*3*3
        flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*32*32
        memory += _listrrelu[i]*_listrrelu[i-1]*3*3
            
    
    if i>36 and i<=72:
        flop110 += 2*_list110[i]*_list110[i-1]*3*3*16*16
        memory110 += _list110[i]*_list110[i-1]*3*3
        flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*16*16
        memory += _listrrelu[i]*_listrrelu[i-1]*3*3
        
    if i>72 and i<=108:
        flop110 += 2*_list110[i]*_list110[i-1]*3*3*8*8
        memory110 += _list110[i]*_list110[i-1]*3*3
        flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*8*8
        memory += _listrrelu[i]*_listrrelu[i-1]*3*3

    if i==109:
        flop110 += 2*_list110[i]*_list110[i-1]
        memory110 += _list110[i]*_list110[i-1]
        flop += 2*_listrrelu[i]*_listrrelu[i-1]
        memory += _listrrelu[i]*_listrrelu[i-1]

print(flop110, flop, memory110, memory)


# ######### ResNet164 ##############################################################################
# _list164 = [16, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 16, 16, 64, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 32, 32, 128, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64, 256, 64, 64]
# _list164 = np.append(_list164, 256)
# _list164 = np.append(_list164, 10)

# # _listrrelu = [8, 16, 16, 5, 9, 13, 14, 16, 16, 5, 10, 11, 3, 5, 7, 11, 16, 16, 12, 16, 16, 7, 14, 16, 4, 9, 14, 11, 14, 16, 3, 5, 12, 10, 11, 11, 15, 15, 15, 10, 15, 14, 4, 10, 14, 2, 14, 16, 2, 12, 16, 1, 4, 4, 40, 32, 32, 28, 30, 32, 31, 32, 32, 25, 30, 32, 31, 30, 32, 22, 32, 32, 26, 31, 30, 28, 31, 32, 35, 32, 32, 36, 32, 32, 28, 30, 31, 38, 31, 32, 36, 32, 32, 41, 32, 30, 43, 32, 29, 25, 29, 28, 38, 32, 28, 44, 32, 28, 114, 64, 64, 91, 64, 56, 61, 64, 52, 93, 64, 53, 144, 64, 48, 136, 64, 43, 120, 64, 39, 98, 64, 29, 174, 64, 41, 158, 64, 36, 154, 64, 32, 144, 64, 25, 157, 64, 29, 154, 64, 21, 145, 64, 23, 151, 64, 25, 138, 64, 20, 146, 64, 20]
# # _listrrelu = [1, 8, 16, 7, 13, 14, 7, 11, 11, 8, 14, 16, 5, 9, 11, 11, 14, 14, 4, 9, 15, 4, 13, 14, 13, 16, 16, 14, 16, 16, 1, 4, 6, 9, 13, 14, 8, 16, 15, 8, 12, 13, 20, 15, 15, 7, 12, 16, 1, 5, 9, 9, 11, 14, 31, 32, 32, 24, 29, 32, 15, 26, 31, 32, 31, 30, 22, 30, 32, 33, 32, 32, 33, 32, 31, 35, 32, 32, 39, 32, 32, 19, 31, 29, 48, 32, 31, 36, 32, 29, 35, 32, 28, 33, 31, 30, 32, 32, 26, 13, 17, 19, 21, 23, 20, 35, 32, 22, 104, 64, 57, 94, 64, 51, 80, 64, 38, 102, 64, 52, 96, 64, 42, 149, 64, 40, 147, 64, 38, 134, 64, 34, 148, 64, 29, 127, 63, 26, 146, 64, 24, 154, 64, 27, 129, 64, 25, 142, 64, 25, 129, 63, 20, 109, 64, 14, 111, 62, 16, 110, 61, 17]
# _listrrelu = [1, 5, 15, 6, 10, 15, 14, 15, 16, 17, 16, 16, 13, 16, 16, 6, 15, 16, 5, 10, 12, 14, 16, 16, 13, 16, 16, 20, 16, 16, 13, 16, 16, 16, 16, 16, 11, 14, 16, 17, 16, 16, 11, 15, 16, 2, 7, 16, 14, 16, 16, 8, 13, 14, 31, 32, 32, 49, 32, 32, 36, 32, 32, 52, 32, 32, 55, 32, 32, 56, 32, 32, 50, 32, 32, 55, 32, 32, 37, 32, 32, 62, 32, 32, 56, 32, 32, 68, 32, 32, 58, 32, 32, 54, 32, 32, 73, 32, 32, 51, 32, 32, 65, 32, 32, 50, 32, 32, 121, 64, 64, 203, 64, 64, 205, 64, 63, 216, 64, 64, 210, 64, 61, 225, 64, 62, 238, 64, 63, 239, 64, 62, 240, 64, 64, 239, 64, 63, 242, 64, 63, 247, 64, 61, 247, 64, 63, 246, 64, 62, 246, 64, 60, 245, 64, 62, 233, 64, 57, 226, 64, 55]
# _listrrelu = np.append(_listrrelu, 256)
# _listrrelu = np.append(_listrrelu, 10)

# for i in range(len(_list164)):
#     print(i, _list164[i], _listrrelu[i])



# ## Network slimming
# k = 3
# h_out_h_0 = 32
# h_out_w_0 = 32
# h_out_h_1 = 16
# h_out_w_1 = 16
# h_out_h_2 = 8
# h_out_w_2 = 8

# flop164 = 0
# memory164 = 0
# flop = 0
# memory = 0

# ## For resnet164
# for i in range(len(_listrrelu)):
    
#     if i==0:
#         flop164 += 2*_list164[i]*3*3*3*32*32
#         memory164 += _list164[i]*3*3*3
#         flop += 2*_listrrelu[i]*3*3*3*32*32
#         memory += _listrrelu[i]*3*3*3

#     if i>0 and i<=54:
#         if i%3 == 0 or i%3 == 1:
#             flop164 += 2*_list164[i]*_list164[i-1]*1*1*32*32
#             memory164 += _list164[i]*_list164[i-1]*1*1
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*1*1*32*32
#             memory += _listrrelu[i]*_listrrelu[i-1]*1*1
#         elif i%3 == 2:
#             flop164 += 2*_list164[i]*_list164[i-1]*3*3*32*32
#             memory164 += _list164[i]*_list164[i-1]*3*3
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*32*32
#             memory += _listrrelu[i]*_listrrelu[i-1]*3*3
            
    
#     if i>54 and i<=108:
#         if i%3 == 0 or i%3 == 1:
#             flop164 += 2*_list164[i]*_list164[i-1]*1*1*16*16
#             memory164 += _list164[i]*_list164[i-1]*1*1
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*1*1*16*16
#             memory += _listrrelu[i]*_listrrelu[i-1]*1*1
#         elif i%3 == 2:
#             flop164 += 2*_list164[i]*_list164[i-1]*3*3*16*16
#             memory164 += _list164[i]*_list164[i-1]*3*3
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*16*16
#             memory += _listrrelu[i]*_listrrelu[i-1]*3*3
        
#     if i>108 and i<=162:
#         if i%3 == 0 or i%3 == 1:
#             flop164 += 2*_list164[i]*_list164[i-1]*1*1*8*8
#             memory164 += _list164[i]*_list164[i-1]*1*1
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*1*1*8*8
#             memory += _listrrelu[i]*_listrrelu[i-1]*1*1
#         elif i%3 == 2:
#             flop164 += 2*_list164[i]*_list164[i-1]*3*3*8*8
#             memory164 += _list164[i]*_list164[i-1]*3*3
#             flop += 2*_listrrelu[i]*_listrrelu[i-1]*3*3*8*8
#             memory += _listrrelu[i]*_listrrelu[i-1]*3*3

#     if i==163:
#         flop164 += 2*10*_list164[i-1]
#         memory164 += 10*_list164[i-1]
#         flop += 2*10*_listrrelu[i-1]
#         memory += 10*_listrrelu[i-1]

# print(flop164, flop, memory164, memory)