import numpy as np


# Setting 1
# Hard-coded raw results, nested as
#     method name -> architecture name -> metric name -> list of measurements
# with the metric names 'accuracy', 'fidelity' and 'asr'. The reporting loop
# further down in this file reduces every list to its mean and standard
# deviation.
# NOTE(review): presumably each list holds one value per repeated run, in
# percent, and 'asr' presumably stands for an attack success rate -- confirm
# against the experiment logs. The variable name suggests a CIFAR-10/CIFAR-100
# dataset pairing; which dataset plays which role is not stated here.
cifar10_cifar100_dict = {
    'papernot': {
        'resnet50': {
            'accuracy': [24.05, 24.68, 25.61, 21.94, 23.56],
            'fidelity': [23.94, 24.83, 25.60, 22.18, 23.47],
            'asr': [31.52, 33.33, 36.11, 39.46, 38.38]
        },
        'pyramidnet': {
            'accuracy': [20.99, 23.12, 21.17, 23.35, 25.62],
            'fidelity': [20.99, 23.01, 20.93, 23.23, 25.64],
            'asr': [31.52, 29.44, 35.00, 35.96, 30.85]
        },
        'resnext': {
            'accuracy': [23.63, 22.83, 22.90, 22.73, 24.37],
            'fidelity': [23.66, 22.68, 23.00, 22.37, 24.30],
            'asr': [32.61, 33.33, 38.89, 41.62, 39.89]
        },
        'wrn-28': {
            'accuracy': [23.14, 20.16, 19.70, 23.61, 25.33],
            'fidelity': [22.67, 20.12, 19.38, 23.75, 25.43],
            'asr': [25.00, 30.05, 27.42, 30.34, 28.72]
        },
    },
    'mosafi': {
        'resnet50': {
            'accuracy': [10.00, 13.84, 18.82, 15.86, 14.35],
            'fidelity': [9.91, 13.91, 18.94, 15.98, 14.55],
            'asr': [33.16, 37.04, 32.09, 43.02, 36.96]
        },
        'pyramidnet': {
            'accuracy': [11.52, 10.75, 10.33, 10.00, 10.09],
            'fidelity': [11.42, 10.74, 10.23, 9.91, 10.00],
            'asr': [33.68, 37.04, 34.22, 31.49, 29.89]
        },
        'resnext': {
            'accuracy': [13.57, 17.83, 17.77, 19.15, 17.14],
            'fidelity': [13.48, 17.84, 17.86, 19.19, 17.26],
            'asr': [35.98, 39.68, 33.16, 35.36, 38.59]
        },
        'wrn-28': {
            'accuracy': [15.12, 10.00, 10.74, 16.47, 14.79],
            'fidelity': [15.08, 9.91, 10.63, 16.70, 14.95],
            'asr': [32.28, 34.92, 26.74, 39.11, 29.35]
        },
    },
    'black-box ripper': {
        'resnet50': {
            'accuracy': [10.36, 10.23, 7.96, 9.96, 9.34],
            'fidelity': [10.38, 10.1, 8.06, 10.04, 9.46],
            'asr': [27.13, 36.36, 34.27, 37.57, 37.43]
        },
        'pyramidnet': {
            'accuracy': [9.32, 9.85, 9.66, 9.99, 10.54],
            'fidelity': [9.06, 9.87, 9.52, 10.23, 10.52],
            'asr': [40.43, 42.06, 43.31, 30.39, 40.99]
        },
        'resnext': {
            'accuracy': [10.75, 14.47, 9.53, 10.23, 10.23],
            'fidelity': [10.64, 14.41, 9.56, 10.07, 10.52],
            'asr': [35.11, 35.29, 30.90, 38.67, 35.29]
        },
        'wrn-28': {
            'accuracy': [6.86, 8.68, 12.15, 9.46, 7.99],
            'fidelity': [7.06, 8.93, 12.12, 9.52, 8.14],
            'asr': [31.91, 32.09, 38.20, 36.46, 42.78]
        },
    },
    'activethief': {
        'resnet50': {
            'accuracy': [28.16, 30.01, 28.72, 30.05, 29.06],
            'fidelity': [28.76, 30.54, 29.32, 30.37, 29.23],
            'asr': [31.15, 33.51, 35.16, 36.90, 36.56]
        },
        'pyramidnet': {
            'accuracy': [38.44, 41.56, 41.71, 39.78, 43.93],
            'fidelity': [39.03, 42.63, 42.25, 40.28, 44.58],
            'asr': [37.99, 43.17, 44.86, 36.90, 44.09]
        },
        'resnext': {
            'accuracy': [31.95, 32.35, 32.74, 31.37, 32.68],
            'fidelity': [32.32, 32.99, 33.34, 31.41, 33.16],
            'asr': [35.52, 36.70, 38.38, 37.97, 38.71]
        },
        'wrn-28': {
            'accuracy': [32.84, 33.68, 30.41, 31.64, 32.47],
            'fidelity': [33.21, 34.21, 31.27, 31.91, 32.91],
            'asr': [33.52, 31.38, 35.14, 34.76, 33.87]
        },
    },
    # NOTE(review): for every DFMS architecture the 'accuracy' and 'fidelity'
    # lists are identical -- confirm this is intended and not a copy-paste.
    'DFMS': {
        'resnet50': {
            'accuracy': [10.22, 10.35, 11.14, 10.58, 11.48],
            'fidelity': [10.22, 10.35, 11.14, 10.58, 11.48],
            'asr': [23.28, 28.00, 22.04, 25.0, 27.96]
        },
        'pyramidnet': {
            'accuracy': [10.73, 12.07, 12.28, 10.0, 12.48],
            'fidelity': [10.73, 12.07, 12.28, 10.0, 12.48],
            'asr': [35.20, 33.33, 30.85, 32.28, 35.36]
        },
        'resnext': {
            'accuracy': [11.0, 10.11, 9.95, 10.0, 11.69],
            'fidelity': [11.0, 10.11, 9.95, 10.0, 11.69],
            'asr': [38.02, 39.56, 33.69, 37.36, 36.61]
        },
        'wrn-28': {
            'accuracy': [12.28, 9.99, 12.26, 11.47, 12.31],
            'fidelity': [12.28, 9.99, 12.26, 11.47, 12.31],
            'asr': [27.75, 32.79, 37.23, 36.72, 31.67]
        },
    },
    # The reporting loop prints the rows of this method in bold (\textbf).
    'seeker':{
        'resnet50':{
            'accuracy': [71.04, 70.64, 71.41, 73.05, 70.92],
            'fidelity': [71.98, 71.87, 72.84, 74.26, 72.39],
            'asr':[79.68, 85.47, 79.23, 82.26, 78.09]
        },
        'pyramidnet':{
            'accuracy': [67.28, 66.37, 67.90, 68.58, 67.85],
            'fidelity': [68.42, 67.54, 68.89, 70.04, 68.96],
            'asr':[67.74, 69.95, 72.34, 74.05, 73.37]
        },
        'resnext':{
            'accuracy': [55.41, 56.64, 56.72, 56.98, 58.49], # (3000 rewritten )
            'fidelity': [56.59, 57.71, 57.62, 58.23, 59.58],
            'asr':[46.41, 47.80, 50.00, 49.21, 52.22]
        },
        'wrn-28':{
            'accuracy': [69.95, 72.12, 74.78, 71.93, 71.98],
            'fidelity': [71.17, 73.57, 75.80, 72.99, 73.25],
            'asr':[84.66, 82.02, 85.26, 81.52, 78.92]
        },
    }
}


# Setting 2
# Hard-coded raw results, nested as
#     method name -> architecture name -> metric name -> list of measurements
# with the metric names 'accuracy', 'fidelity' and 'asr'. The reporting loop
# further down in this file looks entries up here by the method and
# architecture names of the Setting 1 table, so both tables must share the
# same keys.
# NOTE(review): presumably each list holds one value per repeated run, in
# percent, and 'asr' presumably stands for an attack success rate -- confirm
# against the experiment logs. The variable name suggests a CIFAR-100/CIFAR-10
# dataset pairing; which dataset plays which role is not stated here.
cifar100_cifar10_dict = {
    'papernot': {
        'resnet50': {
            'accuracy': [5.55, 5.72, 6.36, 5.28, 5.82],
            'fidelity': [5.52, 5.87, 6.34, 5.48, 6.18],
            'asr': [75.52, 69.28, 80.41, 78.15, 80.71]
        },
        'pyramidnet': {
            'accuracy': [8.15, 7.17, 5.68, 7.48, 5.40],
            'fidelity': [8.38, 7.23, 5.77, 7.44, 5.43],
            'asr': [74.83, 77.14, 77.48, 74.83, 67.09]
        },
        'resnext': {
            'accuracy': [5.41, 5.10, 6.14, 5.41, 5.63],
            'fidelity': [5.40, 5.11, 6.28, 5.51, 5.82],
            'asr': [81.12, 79.08, 77.70, 80.13, 72.78]
        },
        'wrn-28': {
            'accuracy': [5.23, 5.25, 5.84, 5.05, 4.60],
            'fidelity': [5.47, 5.23, 6.19, 5.18, 4.59],
            'asr': [58.74, 62.85, 60.93, 54.97, 49.37]
        },
    },
    'mosafi': {
        'resnet50': {
            'accuracy': [1.00, 0.98, 1.00, 1.00, 1.00],
            'fidelity': [0.96, 0.88, 0.89, 0.96, 0.96],
            'asr': [72.60, 72.73, 77.78, 69.13, 75.48]
        },
        'pyramidnet': {
            'accuracy': [1.00, 1.00, 1.03, 1.09, 0.98],
            'fidelity': [1.13, 0.88, 1.00, 0.96, 0.84],
            'asr': [76.03, 71.43, 72.37, 70.47, 70.97]
        },
        'resnext': {
            'accuracy': [1.40, 1.37, 1.26, 1.02, 1.13],
            'fidelity': [1.08, 1.17, 1.12, 0.99, 1.28],
            'asr': [75.34, 74.03, 75.16, 67.11, 74.19]
        },
        'wrn-28': {
            'accuracy': [1.00, 1.23, 1.00, 1.02, 0.94],
            'fidelity': [0.96, 1.09, 0.96, 0.94, 0.93],
            'asr': [70.55, 72.08, 71.05, 71.14, 72.90]
        },
    },
    'black-box ripper': {
        'resnet50': {
            'accuracy': [1.25, 1.05, 1.12, 0.82, 1.42],
            'fidelity': [1.20, 1.15, 1.20, 0.85, 1.31],
            'asr': [67.53, 66.88, 70.82, 71.71, 71.05]
        },
        'pyramidnet': {
            'accuracy': [1.14, 0.64, 1.23, 0.87, 0.85],
            'fidelity': [1.11, 0.55, 1.39, 0.98, 0.84],
            'asr': [63.64, 70.70, 73.43, 73.29, 71.71]
        },
        'resnext': {
            'accuracy': [1.38, 0.95, 0.58, 1.37, 0.50],
            'fidelity': [1.40, 0.93, 0.45, 1.28, 0.58],
            'asr': [68.18, 66.88, 71.05, 72.40, 73.03]
        },
        'wrn-28': {
            'accuracy': [0.74, 0.99, 1.02, 1.38, 1.05],
            'fidelity': [0.72, 1.02, 1.08, 1.41, 1.00],
            'asr': [63.64, 68.15, 71.47, 73.97, 73.68]
        },
    },
    'activethief': {
        'resnet50': {
            'accuracy': [13.30, 13.33, 15.72, 13.47, 13.42],
            'fidelity': [13.61, 13.27, 15.89, 13.51, 13.80],
            'asr': [70.19, 77.55, 76.87, 77.40, 77.22]
        },
        # NOTE(review): only four values per metric here (every other entry
        # has five); the trailing comma suggests a fifth value is still
        # missing, so mean/std for this row are computed over four values.
        'pyramidnet': {
            'accuracy': [19.15, 19.99, 20.78, 18.48, ], # 5000
            'fidelity': [19.75, 20.42, 21.37, 18.84, ],
            'asr': [66.46, 69.08, 70.75, 74.66, ]
        },
        'resnext': {
            'accuracy': [13.99, 13.34, 16.00, 14.95, 14.50], #5000
            'fidelity': [14.37, 13.58, 16.16, 15.05, 14.60],
            'asr': [72.08, 70.75, 76.20, 73.58, 72.05]
        },
        'wrn-28': {
            'accuracy': [14.13, 14.56, 17.56, 13.17, 13.52],
            'fidelity': [14.37, 14.94, 17.59, 13.23, 13.63],
            'asr': [57.14, 60.53, 63.16, 60.96, 63.92]
        },
    },
    # NOTE(review): for every DFMS architecture the 'accuracy' and 'fidelity'
    # lists are identical -- confirm this is intended and not a copy-paste.
    'DFMS': {
        'resnet50': {
            'accuracy': [1.0, 0.69, 0.82, 0.83, 1.44],
            'fidelity': [1.0, 0.69, 0.82, 0.83, 1.44],
            'asr': [76.62, 74.15, 68.28, 70.32, 79.33]
        },
        'pyramidnet': {
            'accuracy': [1.1, 1.02, 0.98, 1.11, 1.31],
            'fidelity': [1.1, 1.02, 0.98, 1.11, 1.31],
            'asr': [72.37, 75.86, 74.13, 74.12, 72.30]
        },
        'resnext': {
            'accuracy': [1.04, 0.89, 0.98, 0.93, 1.08],
            'fidelity': [1.04, 0.89, 0.98, 0.93, 1.08],
            'asr': [70.39, 70.51, 70.78, 67.57, 71.81]
        },
        'wrn-28': {
            'accuracy': [1.3, 1.28, 1.53, 1.3, 0.98],
            'fidelity': [1.3, 1.28, 1.53, 1.3, 0.98],
            'asr': [73.79, 72.41, 59.60, 65.10, 70.0]
        },
    },
    # The reporting loop prints the rows of this method in bold (\textbf).
    'seeker':{
        'resnet50': {
            'accuracy': [35.34, 31.68, 33.55, 31.58, 32.14],
            'fidelity': [36.23, 32.70, 34.69, 32.69, 33.40],
            'asr': [86.67, 85.92, 90.01, 86.39, 81.29]
        },
        'pyramidnet': {
            'accuracy': [31.21, 29.42, 30.44, 31.14, 31.00],
            'fidelity': [32.71, 31.08, 32.01, 32.28, 32.28],
            'asr': [84.91, 84.97, 87.16, 87.50, 83.54]
        },
        'resnext': {
            'accuracy': [21.21, 19.05, 20.43, 20.10, 21.01], #(4000 rewritten)
            'fidelity': [21.99, 19.79, 21.62, 20.85, 22.22],
            'asr': [79.73, 81.53, 78.29, 76.97, 74.51]
        },
        'wrn-28': {
            'accuracy': [35.21, 37.52, 36.05, 37.87, 37.01],
            'fidelity': [37.24, 38.99, 37.52, 39.52, 38.47],
            'asr': [87.84, 89.94, 86.71, 84.21, 89.47]
        },
    }
}

# print('================ Setting 1 ================')
# for method_name in cifar10_cifar100_dict.keys():
#     print(f'[{method_name}]')
#     method_dict = cifar10_cifar100_dict[method_name]
#     for arch_name in method_dict.keys():
#         arch_dict = method_dict[arch_name]
#         accuracy = arch_dict['accuracy']
#         fidelity = arch_dict['fidelity']
#         asr = arch_dict['asr']
#         if len(accuracy) == 0:
#             print('no data.')
#             break
#         print(f'    |{arch_name}| accuracy:{np.mean(accuracy):.1f}({np.std(accuracy):.1f})'+
#               f'fidelity:{np.mean(fidelity):.1f}({np.std(fidelity):.1f})'+
#               f'asr:{np.mean(asr):.1f}({np.std(asr):.1f})')
# print('================ Setting 2 ================')
# for method_name in cifar100_cifar10_dict.keys():
#     print(f'[{method_name}]')
#     method_dict = cifar100_cifar10_dict[method_name]
#     for arch_name in method_dict.keys():
#         arch_dict = method_dict[arch_name]
#         accuracy = arch_dict['accuracy']
#         fidelity = arch_dict['fidelity']
#         asr = arch_dict['asr']
#         if len(accuracy) == 0:
#             print('no data.')
#             break
#         print(f'    |{arch_name}| accuracy:{np.mean(accuracy):.1f}({np.std(accuracy):.1f})'+
#               f'fidelity:{np.mean(fidelity):.1f}({np.std(fidelity):.1f})'+
#               f'asr:{np.mean(asr):.1f}({np.std(asr):.1f})')

def _mean_std_cell(values, bold=False):
    r"""Format one LaTeX table cell as ``mean ($\pm$std)`` with one decimal.

    Args:
        values: Sequence of numeric measurements to aggregate.
        bold: When true, wrap the whole cell in ``\textbf{...}``.

    Returns:
        The formatted cell text. ``np.std`` is called with its defaults, i.e.
        the population standard deviation (``ddof=0``).
    """
    # '\\pm' (escaped backslash) instead of '\p': the latter is an invalid
    # escape sequence that newer Python versions warn about; output is the same.
    cell = f'{np.mean(values):.1f} ($\\pm${np.std(values):.1f})'
    return f'\\textbf{{{cell}}}' if bold else cell


# Print one LaTeX table row per (method, architecture): the three Setting 1
# cells (accuracy, fidelity, asr) followed by the three Setting 2 cells,
# joined by ' & ' and terminated by a LaTeX line break. Rows are grouped under
# a '[method]' header line and prefixed with '|architecture|'.
for method_name, method_dict in cifar10_cifar100_dict.items():
    print(f'[{method_name}]')
    # Only the 'seeker' rows are typeset in bold.
    highlight = method_name == 'seeker'
    for arch_name, setting1 in method_dict.items():
        # Setting 2 is looked up with the Setting 1 keys; a method or
        # architecture missing from the Setting 2 table raises KeyError.
        setting2 = cifar100_cifar10_dict[method_name][arch_name]
        if len(setting1['accuracy']) == 0:
            # No Setting 1 measurements: report it once and skip the
            # remaining architectures of this method.
            print('no data.')
            break
        cells = [
            _mean_std_cell(setting[metric], bold=highlight)
            for setting in (setting1, setting2)
            for metric in ('accuracy', 'fidelity', 'asr')
        ]
        print(f'    |{arch_name}| ' + ' & '.join(cells) + ' \\\\')

# print('================ Setting 2 ================')
# for method_name in cifar100_cifar10_dict.keys():
#     print(f'[{method_name}]')
#     method_dict = cifar100_cifar10_dict[method_name]
#     for arch_name in method_dict.keys():
#         arch_dict = method_dict[arch_name]
#         accuracy = arch_dict['accuracy']
#         fidelity = arch_dict['fidelity']
#         asr = arch_dict['asr']
#         if len(accuracy) == 0:
#             print('no data.')
#             break
#         if method_name != 'seeker':
#             print(f'    |{arch_name}| {np.mean(accuracy):.1f} ($\pm${np.std(accuracy):.1f}) & ' +
#                   f'{np.mean(fidelity):.1f} ($\pm${np.std(fidelity):.1f}) & ' +
#                   f'{np.mean(asr):.1f} ($\pm${np.std(asr):.1f}) \\\\')
#         else:
#             print(f'    |{arch_name}| \\textbf{{{np.mean(accuracy):.1f} ($\pm${np.std(accuracy):.1f})}} & ' +
#                   f'\\textbf{{{np.mean(fidelity):.1f} ($\pm${np.std(fidelity):.1f})}} & ' +
#                   f'\\textbf{{{np.mean(asr):.1f} ($\pm${np.std(asr):.1f})}} \\\\')