import torch

def normalize_adjacency(adj):
    """Symmetrically normalize an adjacency matrix: ``D^{-1/2} @ A @ D^{-1/2}``.

    ``D`` is the diagonal matrix of the row sums of ``adj``. Rows whose sum is
    zero get a scaling factor of 0 instead of ``inf``, so the corresponding row
    and column of the result are all zeros.

    Args:
        adj: Square matrix of shape ``(n, n)``; a stack of matrices
            ``(..., n, n)`` is also accepted. Any numeric or bool dtype works.

    Returns:
        A float32 tensor with the same shape and device as ``adj``.
    """
    # Work in float32 throughout; the old version mixed adj's own dtype with
    # adj.float() inside matmul and crashed for float64/float16 inputs.
    adj = adj.float()
    degree = adj.sum(dim=-1)
    inv_sqrt_degree = degree.pow(-0.5)
    # 0 ** -0.5 is inf: rows with zero sum must not blow up, so scale them by 0.
    inv_sqrt_degree = torch.where(
        torch.isinf(inv_sqrt_degree), torch.zeros_like(inv_sqrt_degree), inv_sqrt_degree
    )
    # Scaling rows and columns by broadcasting equals multiplying by the diagonal
    # matrix on both sides, but costs O(n^2) instead of two O(n^3) matmuls.
    return inv_sqrt_degree.unsqueeze(-1) * adj * inv_sqrt_degree.unsqueeze(-2)

def get_normalization_factors(dataset):
    """Return the hard-coded per-feature ``(means, std_devs)`` for ``dataset``.

    Each returned tensor is a 1-D float32 tensor of length 140, created on the
    CPU on every call.

    Args:
        dataset: Dataset identifier; one of ``'zpn/clintox'``, ``'zpn/bbbp'``
            or ``'zpn/tox21_srp53'``.

    Returns:
        Tuple ``(means, std_devs)``. Several ``std_devs`` entries are 0, so
        callers must guard against division by zero.

    Raises:
        ValueError: If ``dataset`` is not one of the supported identifiers
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if dataset == 'zpn/clintox':
        means = torch.tensor([7.7646e-05, 0.0000e+00, 0.0000e+00, 7.7646e-05, 7.1941e-01, 9.3046e-02,
            1.5325e-01, 9.6281e-03, 0.0000e+00, 0.0000e+00, 2.5882e-04, 7.7646e-05,
            1.4235e-03, 1.1000e-02, 8.1269e-03, 0.0000e+00, 2.5882e-05, 2.5882e-05,
            0.0000e+00, 2.5882e-05, 2.5882e-05, 2.5882e-05, 0.0000e+00, 0.0000e+00,
            2.5882e-05, 5.1764e-05, 0.0000e+00, 5.1764e-05, 5.1764e-05, 3.8823e-04,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5882e-05, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.7176e-03, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1764e-05, 2.5882e-05,
            5.1764e-05, 2.5882e-05, 0.0000e+00, 2.5882e-05, 2.5882e-05, 1.5193e-02,
            9.6462e-01, 2.0110e-02, 2.5882e-05, 2.5882e-05, 2.8470e-04, 2.3087e-01,
            4.3432e-01, 3.0572e-01, 2.8781e-02, 0.0000e+00, 2.5882e-05, 9.2111e-01,
            3.7581e-02, 4.1308e-02, 4.1095e-01, 3.2681e-01, 1.8003e-01, 8.2201e-02,
            0.0000e+00, 0.0000e+00, 7.7646e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            7.7646e-05, 7.1941e-01, 9.3046e-02, 1.5325e-01, 9.6281e-03, 0.0000e+00,
            0.0000e+00, 2.5882e-04, 7.7646e-05, 1.3976e-03, 2.5882e-05, 1.1000e-02,
            8.1269e-03, 0.0000e+00, 2.5882e-05, 2.5882e-05, 0.0000e+00, 2.5882e-05,
            2.5882e-05, 2.5882e-05, 0.0000e+00, 0.0000e+00, 2.5882e-05, 5.1764e-05,
            0.0000e+00, 5.1764e-05, 5.1764e-05, 3.8823e-04, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 2.5882e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 2.5882e-05, 2.5623e-03, 1.2941e-04, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1764e-05, 2.5882e-05,
            5.1764e-05, 2.5882e-05, 0.0000e+00, 0.0000e+00, 2.5882e-05, 6.9410e-01,
            3.0590e-01, 2.5882e-05, 1.2941e-04, 3.6752e-03, 5.4756e-01, 4.4851e-01,
            2.5882e-05, 7.7646e-05])

        std_devs = torch.tensor([0.0088, 0.0000, 0.0000, 0.0088, 0.4493, 0.2905, 0.3602, 0.0977, 0.0000,
            0.0000, 0.0161, 0.0088, 0.0377, 0.1043, 0.0898, 0.0000, 0.0051, 0.0051,
            0.0000, 0.0051, 0.0051, 0.0051, 0.0000, 0.0000, 0.0051, 0.0072, 0.0000,
            0.0072, 0.0072, 0.0197, 0.0000, 0.0000, 0.0000, 0.0051, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0521, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0072, 0.0051, 0.0072, 0.0051, 0.0000, 0.0051, 0.0051, 0.1223,
            0.1847, 0.1404, 0.0051, 0.0051, 0.0169, 0.4214, 0.4957, 0.4607, 0.1672,
            0.0000, 0.0051, 0.2696, 0.1902, 0.1990, 0.4920, 0.4691, 0.3842, 0.2747,
            0.0000, 0.0000, 0.0088, 0.0000, 0.0000, 0.0000, 0.0088, 0.4493, 0.2905,
            0.3602, 0.0977, 0.0000, 0.0000, 0.0161, 0.0088, 0.0374, 0.0051, 0.1043,
            0.0898, 0.0000, 0.0051, 0.0051, 0.0000, 0.0051, 0.0051, 0.0051, 0.0000,
            0.0000, 0.0051, 0.0072, 0.0000, 0.0072, 0.0072, 0.0197, 0.0000, 0.0000,
            0.0000, 0.0051, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0051,
            0.0506, 0.0114, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0072, 0.0051,
            0.0072, 0.0051, 0.0000, 0.0000, 0.0051, 0.4608, 0.4608, 0.0051, 0.0114,
            0.0605, 0.4977, 0.4973, 0.0051, 0.0088])

    elif dataset == 'zpn/bbbp':
        means = torch.tensor([8.1520e-04, 0.0000e+00, 0.0000e+00, 2.0380e-05, 7.4582e-01, 8.9488e-02,
            1.2746e-01, 1.0373e-02, 4.2798e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            2.2418e-04, 1.1617e-02, 1.2595e-02, 0.0000e+00, 2.0380e-05, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0190e-03,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2228e-04, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5475e-03,
            9.9495e-01, 2.4863e-03, 2.0380e-05, 0.0000e+00, 2.8532e-03, 2.0688e-01,
            4.5156e-01, 3.0780e-01, 3.0916e-02, 0.0000e+00, 0.0000e+00, 9.4049e-01,
            2.9979e-02, 2.9530e-02, 4.0937e-01, 3.3297e-01, 1.7869e-01, 7.8972e-02,
            0.0000e+00, 0.0000e+00, 8.1520e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            2.0380e-05, 7.4582e-01, 8.9488e-02, 1.2746e-01, 1.0373e-02, 4.2798e-04,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 2.2418e-04, 0.0000e+00, 1.1617e-02,
            1.2595e-02, 0.0000e+00, 2.0380e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0190e-03, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2228e-04, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.6754e-01,
            3.3246e-01, 0.0000e+00, 1.2636e-03, 2.0584e-03, 5.4989e-01, 4.4679e-01,
            0.0000e+00, 0.0000e+00])
        std_devs = torch.tensor([0.0285, 0.0000, 0.0000, 0.0045, 0.4354, 0.2854, 0.3335, 0.1013, 0.0207,
            0.0000, 0.0000, 0.0000, 0.0150, 0.1072, 0.1115, 0.0000, 0.0045, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0319, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0111, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0504,
            0.0709, 0.0498, 0.0045, 0.0000, 0.0533, 0.4051, 0.4977, 0.4616, 0.1731,
            0.0000, 0.0000, 0.2366, 0.1705, 0.1693, 0.4917, 0.4713, 0.3831, 0.2697,
            0.0000, 0.0000, 0.0285, 0.0000, 0.0000, 0.0000, 0.0045, 0.4354, 0.2854,
            0.3335, 0.1013, 0.0207, 0.0000, 0.0000, 0.0000, 0.0150, 0.0000, 0.1072,
            0.1115, 0.0000, 0.0045, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0319, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4711, 0.4711, 0.0000, 0.0355,
            0.0453, 0.4975, 0.4972, 0.0000, 0.0000])

    elif dataset == 'zpn/tox21_srp53':
        means = torch.tensor([4.8124e-05, 1.3750e-05, 6.8748e-06, 1.7874e-04, 7.3223e-01, 7.2068e-02,
            1.4881e-01, 1.1192e-02, 6.1873e-05, 1.3750e-05, 1.5812e-04, 6.4623e-04,
            1.8493e-03, 1.2574e-02, 1.5193e-02, 2.0624e-05, 2.7499e-05, 2.0624e-05,
            6.8748e-06, 6.1873e-05, 2.0624e-05, 4.8124e-05, 2.7499e-05, 2.7499e-05,
            3.4374e-05, 6.1873e-05, 6.8748e-06, 8.2497e-05, 4.1249e-05, 2.6193e-03,
            6.8748e-06, 1.3750e-05, 6.8748e-06, 0.0000e+00, 6.8748e-06, 6.8748e-06,
            2.0624e-05, 2.0624e-05, 1.3750e-04, 2.7499e-05, 1.3612e-03, 2.7499e-05,
            6.8748e-06, 1.3750e-05, 6.8748e-06, 6.8748e-06, 2.0624e-05, 3.4374e-05,
            9.6247e-05, 6.8748e-06, 6.8748e-06, 1.3750e-05, 6.8748e-06, 7.6654e-03,
            9.8674e-01, 5.3005e-03, 2.3374e-04, 5.4998e-05, 4.7436e-04, 2.4342e-01,
            4.6190e-01, 2.6657e-01, 2.7623e-02, 6.8748e-06, 6.8748e-06, 9.5536e-01,
            2.1607e-02, 2.3037e-02, 4.0771e-01, 3.1473e-01, 1.8739e-01, 9.0149e-02,
            6.8748e-06, 6.8748e-06, 6.8748e-06, 4.1249e-05, 1.3750e-05, 6.8748e-06,
            1.7874e-04, 7.3223e-01, 7.2068e-02, 1.4881e-01, 1.1192e-02, 6.1873e-05,
            1.3750e-05, 1.5812e-04, 6.4623e-04, 1.8493e-03, 0.0000e+00, 1.2574e-02,
            1.5193e-02, 2.0624e-05, 2.7499e-05, 2.0624e-05, 6.8748e-06, 6.1873e-05,
            2.0624e-05, 4.8124e-05, 2.7499e-05, 2.7499e-05, 3.4374e-05, 6.1873e-05,
            6.8748e-06, 8.2497e-05, 4.1249e-05, 2.6193e-03, 6.8748e-06, 1.3750e-05,
            6.8748e-06, 0.0000e+00, 6.8748e-06, 6.8748e-06, 2.0624e-05, 2.0624e-05,
            1.3750e-04, 2.7499e-05, 0.0000e+00, 1.3612e-03, 0.0000e+00, 2.7499e-05,
            6.8748e-06, 1.3750e-05, 6.8748e-06, 6.8748e-06, 2.0624e-05, 3.4374e-05,
            9.6247e-05, 0.0000e+00, 6.8748e-06, 6.8748e-06, 1.3750e-05, 6.6509e-01,
            3.3491e-01, 4.8124e-05, 2.9562e-04, 4.7642e-03, 5.5718e-01, 4.3746e-01,
            1.9249e-04, 6.1873e-05])

        std_devs = torch.tensor([0.0069, 0.0037, 0.0026, 0.0134, 0.4428, 0.2586, 0.3559, 0.1052, 0.0079,
            0.0037, 0.0126, 0.0254, 0.0430, 0.1114, 0.1223, 0.0045, 0.0052, 0.0045,
            0.0026, 0.0079, 0.0045, 0.0069, 0.0052, 0.0052, 0.0059, 0.0079, 0.0026,
            0.0091, 0.0064, 0.0511, 0.0026, 0.0037, 0.0026, 0.0000, 0.0026, 0.0026,
            0.0045, 0.0045, 0.0117, 0.0052, 0.0369, 0.0052, 0.0026, 0.0037, 0.0026,
            0.0026, 0.0045, 0.0059, 0.0098, 0.0026, 0.0026, 0.0037, 0.0026, 0.0872,
            0.1144, 0.0726, 0.0153, 0.0074, 0.0218, 0.4291, 0.4985, 0.4422, 0.1639,
            0.0026, 0.0026, 0.2065, 0.1454, 0.1500, 0.4914, 0.4644, 0.3902, 0.2864,
            0.0026, 0.0026, 0.0026, 0.0064, 0.0037, 0.0026, 0.0134, 0.4428, 0.2586,
            0.3559, 0.1052, 0.0079, 0.0037, 0.0126, 0.0254, 0.0430, 0.0000, 0.1114,
            0.1223, 0.0045, 0.0052, 0.0045, 0.0026, 0.0079, 0.0045, 0.0069, 0.0052,
            0.0052, 0.0059, 0.0079, 0.0026, 0.0091, 0.0064, 0.0511, 0.0026, 0.0037,
            0.0026, 0.0000, 0.0026, 0.0026, 0.0045, 0.0045, 0.0117, 0.0052, 0.0000,
            0.0369, 0.0000, 0.0052, 0.0026, 0.0037, 0.0026, 0.0026, 0.0045, 0.0059,
            0.0098, 0.0000, 0.0026, 0.0026, 0.0037, 0.4720, 0.4720, 0.0069, 0.0172,
            0.0689, 0.4967, 0.4961, 0.0139, 0.0079])

    else:
        raise ValueError(
            f"No normalization factors for dataset {dataset!r}; expected one of "
            "'zpn/clintox', 'zpn/bbbp', 'zpn/tox21_srp53'."
        )
    return means, std_devs

def normalize_features(fts, dataset, device='cuda'):
    """Standardize features with the fixed per-dataset statistics.

    Computes ``(fts - means) / std_devs`` using the vectors returned by
    ``get_normalization_factors(dataset)``. The vectors broadcast against the
    last dimension of ``fts``. Any resulting NaN or +/-inf is replaced by 0.
    This covers features whose stored std is 0, as well as NaNs already
    present in ``fts``.

    Args:
        fts: Feature tensor whose last dimension matches the length of the
            normalization vectors.
        dataset: Dataset identifier understood by ``get_normalization_factors``.
        device: Device used for the computation and for the result. Defaults
            to ``'cuda'``, which matches the previously hard-coded ``.cuda()``
            calls; pass ``'cpu'`` to run without a GPU.

    Returns:
        The normalized features, on ``device``.
    """
    means, std_devs = get_normalization_factors(dataset)
    fts = fts.to(device)
    means = means.to(device)
    std_devs = std_devs.to(device)
    normalized_fts = (fts - means) / std_devs
    # A zero std gives 0/0 -> NaN or x/0 -> +/-inf; map all of these to 0.
    return torch.nan_to_num(normalized_fts, nan=0.0, posinf=0.0, neginf=0.0)