# MAD-based backdoor outlier detection on per-class L1 norms (SVHN, CIFAR10 and GTSRB results)
import torch



# Default number of classes and the identity label -> position mapping that
# goes with it (label i is stored at index i of the norm tensor).
num_classes = 10
idx_mapping = list(range(num_classes))



def outlier_detection(l1_norm_list, idx_mapping, threshold=2):
    """Run a MAD-based outlier test on per-class L1 norms and report suspects.

    The anomaly index of a class is ``|norm - median| / MAD`` with
    ``MAD = 1.4826 * median(|norms - median|)``. The model is reported as a
    backdoor model when the anomaly index of the smallest norm is not below
    ``threshold``. Classes whose norm is at or below the median and whose
    anomaly index exceeds ``threshold`` are flagged.

    Args:
        l1_norm_list: 1-D tensor (or a sequence accepted by ``torch.as_tensor``)
            holding one L1 norm per class.
        idx_mapping: iterable of labels that is also indexable by label, giving
            each label's position in ``l1_norm_list``. This can be a dict
            ``{label: index}`` or the identity list ``list(range(n))``. A list is
            iterated by value and then indexed by that value, so only the
            identity list behaves as expected.
        threshold: anomaly-index cut-off used for both the verdict and the
            flagging (default 2).

    Returns:
        List of ``(label, norm)`` tuples for the flagged classes, sorted by norm
        in ascending order. ``norm`` is a 0-d tensor element of ``l1_norm_list``.
    """
    l1_norm_list = torch.as_tensor(l1_norm_list)

    print("-" * 30)
    print("Determining whether model is backdoor")
    consistency_constant = 1.4826
    # NOTE: for an even number of elements torch.median returns the lower of
    # the two middle values; it does not average them.
    median = torch.median(l1_norm_list)
    deviation = torch.abs(l1_norm_list - median)
    mad = consistency_constant * torch.median(deviation)

    anomaly_index = deviation / mad
    # If MAD is 0, entries equal to the median produce 0/0 = nan. Since
    # `nan < threshold` is False, that would wrongly report a backdoor, so
    # treat a zero deviation as a zero anomaly index.
    anomaly_index = torch.where(deviation == 0, torch.zeros_like(anomaly_index), anomaly_index)
    min_mad = anomaly_index[torch.argmin(l1_norm_list)]

    print("Median: {}, MAD: {}".format(median, mad))
    print("Anomaly index: {}".format(min_mad))

    if min_mad < threshold:
        print("Not a backdoor model")
    else:
        print("This is a backdoor model")

    flag_list = []
    for y_label in idx_mapping:
        idx = idx_mapping[y_label]
        # Only unusually small norms are suspicious; skip classes above the median.
        if l1_norm_list[idx] > median:
            continue
        if anomaly_index[idx] > threshold:
            flag_list.append((y_label, l1_norm_list[idx]))

    flag_list.sort(key=lambda item: item[1])

    print(
        "Flagged label list: {}".format(",".join(["{}: {}".format(y_label, l_norm) for y_label, l_norm in flag_list]))
    )
    return flag_list


if __name__ == '__main__':
    # Per-class L1 norms for each experiment. These are presumably the norms of
    # the reverse-engineered trigger for each target label (confirm against the
    # script that produced them). SVHN/CIFAR10 runs have 10 entries; GTSRB runs
    # have 43.
    l1_norm_lists = {
        "SVHN benign": [102.0990, 80.7132, 56.5296, 57.2755, 73.8205, 60.1620, 64.0545, 79.5832, 73.7145, 67.7726],
        "SVHN dirty-label": [55.9934, 61.7941, 50.1486, 42.1746, 55.5424, 47.5258, 57.4964, 68.3007, 61.7859, 61.7190],
        "SVHN clean-label": [96.2974, 84.6181, 60.0090, 49.9106, 69.2450, 56.6759, 68.1424, 80.5633, 69.3771, 80.7085],
        "CIFAR10 benign": [98.7799, 108.4121, 86.6648, 98.8678, 90.6469, 94.2850, 106.8170, 110.3739, 120.6416, 121.9175],
        "CIFAR10 dirty-label": [72.3736, 78.2265, 79.4564, 89.4045, 79.4408, 74.8779, 105.4309, 97.2447, 101.1340, 95.9949],
        "CIFAR10 clean-label": [103.8022, 92.5423, 99.1087, 95.5599, 92.7839, 99.8720, 99.9603, 101.3849, 117.9083, 102.4376],
        "GTSRB benign": [
            66.4830, 66.0538, 56.6743, 71.6981, 70.3372, 70.8646, 96.0844,
            65.3540, 60.6505, 68.1037, 66.3698, 68.1835, 63.3611, 70.1179,
            73.4789, 70.2912, 65.4629, 67.3621, 70.0752, 90.6814, 82.2022,
            74.7106, 87.3538, 70.9720, 81.2701, 83.4412, 76.8413, 79.9171,
            77.5419, 75.8940, 80.7885, 74.8114, 106.1047, 78.3991, 86.1934,
            83.8910, 79.4908, 99.8801, 69.7259, 68.2080, 89.0258, 95.9697,
            79.0114,
        ],
        "GTSRB dirty-label": [
            57.3550, 48.6016, 49.9472, 52.7324, 50.1553, 51.2774, 59.7646, 50.7625,
            41.6229, 45.5935, 41.8993, 46.8380, 43.4334, 52.9589, 41.8286, 58.2637,
            49.3023, 40.2236, 51.5371, 60.7760, 57.7853, 54.2026, 66.4855, 55.6490,
            51.2333, 55.4518, 53.7215, 60.0815, 55.0573, 60.0157, 50.3932, 54.7874,
            71.1637, 59.8522, 47.9805, 61.5787, 63.1259, 56.6095, 47.9876, 58.7299,
            59.0152, 63.2370, 60.0389,
        ],
        "GTSRB clean-label": [
            71.6703, 62.3719, 64.6746, 83.2009, 66.1969, 71.8759, 104.3304,
            65.6223, 74.7942, 69.5948, 70.8438, 67.8519, 70.8000, 70.5492,
            76.0404, 81.4476, 64.7689, 64.6797, 89.2380, 81.0475, 108.9571,
            77.4265, 79.1142, 78.3208, 83.4693, 82.6501, 82.0574, 81.4254,
            83.4702, 80.0919, 87.7846, 88.6663, 98.3407, 61.3899, 85.4673,
            78.4243, 85.2204, 85.7566, 63.3579, 67.6281, 68.0140, 89.7042,
            91.3552,
        ],
    }

    # Experiment to analyse. Previously this was chosen by (un)commenting
    # assignments, and the last one, GTSRB clean-label, was the one used.
    dataset = "GTSRB clean-label"
    l1_norm_list = torch.tensor(l1_norm_lists[dataset])

    # Derive the identity label -> index mapping from the data so every class
    # is checked. A fixed 10-entry mapping never flagged GTSRB labels 10-42.
    idx_mapping = list(range(len(l1_norm_list)))

    outlier_detection(l1_norm_list, idx_mapping)