import pickle
from os.path import join

def get_dataset_info(dataset_name, remove_h):
    """Load the pickled metadata object for a dataset.

    Reads ``data/<dataset_name>/dataset_info.p``, resolved relative to the
    current working directory, and returns whatever object was pickled there.

    Args:
        dataset_name: Name of the sub-directory under ``data`` holding the
            ``dataset_info.p`` file.
        remove_h: Not used by this function; it has no effect on which file
            is read. Kept in the signature so existing callers keep working.

    Returns:
        The object stored in the pickle file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    info_path = join('data', dataset_name, 'dataset_info.p')
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(info_path, 'rb') as handle:
        return pickle.load(handle)


# qm9_with_h = {
#     'name': 'qm9',
#     'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4},
#     'atom_decoder': ['H', 'C', 'N', 'O', 'F'],
#     'max_n_nodes': 29,
#     'n_nodes': {3: 1, 4: 2, 5: 2, 6: 8, 7: 16, 8: 55, 9: 135, 10: 373, 11: 834, 12: 1680, 13: 3040, 
#                 14: 5059, 15: 7519, 16: 9989, 17: 13340, 18: 12534, 19: 16721, 20: 8780, 21: 13715, 
#                 22: 3081, 23: 6756, 24: 532, 25: 1854, 26: 41, 27: 315, 29: 29},
#     'atom_types': {0: 1008306, 1: 673038, 2: 111052, 3: 149243, 4: 2622},
#     'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1'],
#     'radius_dic': [0.46, 0.77, 0.77, 0.77, 0.77],
#     'with_h': True
# }

# qm9_without_h = {
#     'name': 'qm9',
#     'atom_encoder': {'C': 0, 'N': 1, 'O': 2, 'F': 3},
#     'atom_decoder': ['C', 'N', 'O', 'F'],
#     'max_n_nodes': 9,
#     'n_nodes': {2: 3, 3: 6, 4: 26, 5: 107, 6: 507, 7: 2525, 8: 14558, 9: 88679},
#     'atom_types': {0: 673038, 1: 111052, 2: 149243, 3: 2622},
#     'colors_dic': ['C7', 'C0', 'C3', 'C1'],
#     'radius_dic': [0.77, 0.77, 0.77, 0.77],
#     'with_h': False
# }

# zinc250k_with_h = {
#     'name': 'zinc250k',
#     'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9},
#     'atom_decoder': ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'],
#     'max_n_nodes': 83,
#     'n_nodes': {9: 1, 10: 3, 11: 7, 12: 19, 13: 21, 14: 48, 15: 63, 16: 100, 17: 141, 18: 162, 19: 186, 20: 275, 
#                 21: 327, 22: 458, 23: 526, 24: 700, 25: 863, 26: 1142, 27: 1344, 28: 1734, 29: 2148, 30: 2608, 
#                 31: 3208, 32: 3796, 33: 4353, 34: 5218, 35: 5835, 36: 6634, 37: 7438, 38: 8230, 39: 8563, 
#                 40: 9358, 41: 10015, 42: 10212, 43: 10347, 44: 10860, 45: 10433, 46: 10086, 47: 9900, 48: 9142, 
#                 49: 8658, 50: 7991, 51: 7297, 52: 6602, 53: 5639, 54: 5000, 55: 4236, 56: 3697, 57: 3039, 
#                 58: 2415, 59: 1996, 60: 1655, 61: 1254, 62: 1082, 63: 785, 64: 616, 65: 472, 66: 330, 67: 248, 
#                 68: 164, 69: 114, 70: 78, 71: 50, 72: 31, 73: 12, 74: 24, 75: 8, 76: 3, 77: 5, 78: 1, 79: 1, 
#                 80: 3, 83: 1},
#     'atom_types': {0: 4542046, 1: 3753634, 2: 621757, 3: 508274, 4: 69973, 5: 116, 6: 90704, 7: 37828, 8: 11241, 9: 795},
#     'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1', 'C4', 'C8', 'C9', 'C11', 'C12'],
#     'radius_dic': [0.3, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
#     'with_h': True
# }

# zinc250k_without_h = {
#     'name': 'zinc250k',
#     'atom_encoder': {'C': 0, 'N': 1, 'O': 2, 'F': 3, 'P': 4, 'S': 5, 'Cl': 6, 'Br': 7, 'I': 8},
#     'atom_decoder': ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'],
#     'max_n_nodes': 38,
#     'n_nodes': {6: 2, 7: 5, 8: 12, 9: 60, 10: 172, 11: 634, 12: 1042, 13: 1550, 14: 2454, 15: 3828, 16: 5554, 
#                 17: 7598, 18: 10252, 19: 12791, 20: 15683, 21: 18075, 22: 16394, 23: 18530, 24: 20444, 25: 20151, 
#                 26: 17047, 27: 13878, 28: 8863, 29: 6857, 30: 5366, 31: 4232, 32: 3311, 33: 2308, 34: 1551, 35: 902, 
#                 36: 347, 37: 117, 38: 1},
#     'atom_types': {0: 3753634, 1: 621757, 2: 508274, 3: 69973, 4: 116, 5: 90704, 6: 37828, 7: 11241, 8: 795},
#     'colors_dic': ['C7', 'C0', 'C3', 'C1', 'C4', 'C8', 'C9', 'C11', 'C12'],
#     'radius_dic': [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
#     'with_h': False
# }

# zinc250k_explicitH = {
#     'name': 'zinc250k_explicitH',
#     'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9},
#     'atom_decoder': ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'],
#     'allowed_bonds': [1,   4,   3,   2,   1,   5,   4,   1,    1,    1], # P is actually 3 or 5
#     'max_n_nodes': 40,
#     'n_nodes': {6: 1, 7: 4, 8: 12, 9: 54, 10: 135, 11: 417, 12: 734, 13: 1062, 14: 1735, 15: 2791, 16: 4060, 
#                 17: 6082, 18: 8622, 19: 11538, 20: 14581, 21: 17568, 22: 18563, 23: 19829, 24: 20712, 25: 20352, 26: 17694, 
#                 27: 14900, 28: 10546, 29: 7861, 30: 5952, 31: 4545, 32: 3555, 33: 2589, 34: 1746, 35: 1057, 36: 473, 37: 183, 
#                 38: 53, 39: 4, 40: 1},
#     'atom_types': {0: 101282, 1: 3753634, 2: 621757, 3: 508274, 4: 69973, 5: 116, 6: 90704, 7: 37828, 8: 11241, 9: 795},
#     'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1', 'C4', 'C8', 'C9', 'C11', 'C12'],
#     'radius_dic': [0.3, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
#     'with_h': True
# }


#     # if dataset_name == 'qm9':
#     #     if not remove_h:
#     #         return qm9_with_h
#     #     else:
#     #         return qm9_without_h
#     # elif dataset_name == 'zinc250k':
#     #     if not remove_h:
#     #         return zinc250k_with_h
#     #     else:
#     #         return zinc250k_without_h
#     # elif 'zinc250k_explicitH' in  dataset_name:# == 'zinc250k_explicitH' or dataset_name == 'zinc250k_explicitH_penalized_logP':
#     #     return zinc250k_explicitH
#     # else:
#     #     raise Exception("Wrong dataset %s" % dataset_name)


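# # OLD STUFF: earlier hard-coded dataset configurations and the previous
# # dict-based get_dataset_info(), kept commented out for reference only.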
# # qm9_with_h = {
# #     'name': 'qm9',
# #     'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4},
# #     'atom_decoder': ['H', 'C', 'N', 'O', 'F'],
# #     'n_nodes': {22: 3393, 17: 13025, 23: 4848, 21: 9970, 19: 13832, 20: 9482, 16: 10644, 13: 3060,
# #                 15: 7796, 25: 1506, 18: 13364, 12: 1689, 11: 807, 24: 539, 14: 5136, 26: 48, 7: 16, 10: 362,
# #                 8: 49, 9: 124, 27: 266, 4: 4, 29: 25, 6: 9, 5: 5, 3: 1},
# #     'max_n_nodes': 29,
# #     'atom_types': {1: 635559, 2: 101476, 0: 923537, 3: 140202, 4: 2323},
# #     # 'distances': [903054, 307308, 111994, 57474, 40384, 29170, 47152, 414344, 2202212, 573726,
# #     #               1490786, 2970978, 756818, 969276, 489242, 1265402, 4587994, 3187130, 2454868, 2647422,
# #     #               2098884,
# #     #               2001974, 1625206, 1754172, 1620830, 1710042, 2133746, 1852492, 1415318, 1421064, 1223156,
# #     #               1322256,
# #     #               1380656, 1239244, 1084358, 981076, 896904, 762008, 659298, 604676, 523580, 437464, 413974,
# #     #               352372,
# #     #               291886, 271948, 231328, 188484, 160026, 136322, 117850, 103546, 87192, 76562, 61840,
# #     #               49666, 43100,
# #     #               33876, 26686, 22402, 18358, 15518, 13600, 12128, 9480, 7458, 5088, 4726, 3696, 3362, 3396,
# #     #               2484,
# #     #               1988, 1490, 984, 734, 600, 456, 482, 378, 362, 168, 124, 94, 88, 52, 44, 40, 18, 16, 8, 6,
# #     #               2,
# #     #               0, 0, 0, 0,
# #     #               0,
# #     #               0, 0],
# #     'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1'],
# #     'radius_dic': [0.46, 0.77, 0.77, 0.77, 0.77],
# #     'with_h': True}
# #     # 'bond1_radius': {'H': 31, 'C': 76, 'N': 71, 'O': 66, 'F': 57},
# #     # 'bond1_stdv': {'H': 5, 'C': 2, 'N': 2, 'O': 2, 'F': 3},
# #     # 'bond2_radius': {'H': -1000, 'C': 67, 'N': 60, 'O': 57, 'F': 59},
# #     # 'bond3_radius': {'H': -1000, 'C': 60, 'N': 54, 'O': 53, 'F': 53}}

# # qm9_without_h = {
# #     'name': 'qm9',
# #     'atom_encoder': {'C': 0, 'N': 1, 'O': 2, 'F': 3},
# #     'atom_decoder': ['C', 'N', 'O', 'F'],
# #     'max_n_nodes': 29,
# #     'n_nodes': {9: 83366, 8: 13625, 7: 2404, 6: 475, 5: 91, 4: 25, 3: 7, 1: 2, 2: 5},
# #     'atom_types': {0: 635559, 2: 140202, 1: 101476, 3: 2323},
# #     'distances': [594, 1232, 3706, 4736, 5478, 9156, 8762, 13260, 45674, 174676, 469292,
# #                     1182942, 126722, 25768, 28532, 51696, 232014, 299916, 686590, 677506,
# #                     379264, 162794, 158732, 156404, 161742, 156486, 236176, 310918, 245558,
# #                     164688, 98830, 81786, 89318, 91104, 92788, 83772, 81572, 85032, 56296,
# #                     32930, 22640, 24124, 24010, 22120, 19730, 21968, 18176, 12576, 8224,
# #                     6772,
# #                     3906, 4416, 4306, 4110, 3700, 3592, 3134, 2268, 774, 674, 514, 594, 622,
# #                     672, 642, 472, 300, 170, 104, 48, 54, 78, 78, 56, 48, 36, 26, 4, 2, 4,
# #                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# #     'colors_dic': ['C7', 'C0', 'C3', 'C1'],
# #     'radius_dic': [0.77, 0.77, 0.77, 0.77],
# #     'with_h': False}
# #     # 'bond1_radius': {'C': 76, 'N': 71, 'O': 66, 'F': 57},
# #     # 'bond1_stdv': {'C': 2, 'N': 2, 'O': 2, 'F': 3},
# #     # 'bond2_radius': {'C': 67, 'N': 60, 'O': 57, 'F': 59},
# #     # 'bond3_radius': {'C': 60, 'N': 54, 'O': 53, 'F': 53}}


# # qm9_second_half = {
# #     'name': 'qm9_second_half',
# #     'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4},
# #     'atom_decoder': ['H', 'C', 'N', 'O', 'F'],
# #     'n_nodes': {19: 6944, 12: 845, 20: 4794, 21: 4962, 27: 132, 25: 754, 18: 6695, 14: 2587, 15: 3865, 22: 1701, 17: 6461, 16: 5344, 23: 2380, 13: 1541, 24: 267, 10: 178, 7: 7, 11: 412, 8: 25, 9: 62, 29: 15, 26: 17, 4: 3, 3: 1, 6: 5, 5: 3},
# #     'atom_types': {1: 317604, 2: 50852, 3: 70033, 0: 461622, 4: 1164},
# #     'distances': [457374, 153688, 55626, 28284, 20414, 15010, 24412, 208012, 1105440, 285830, 748876, 1496486, 384178, 484194, 245688, 635534, 2307642, 1603762, 1231044, 1329758, 1053612, 1006742, 813504, 880670, 811616, 855082, 1066434, 931672, 709810, 711032, 608446, 660538, 692382, 619084, 544200, 490740, 450576, 380662, 328150, 303008, 263888, 218820, 207414, 175452, 145636, 135646, 116184, 94622, 80358, 68230, 58706, 51216, 44020, 38212, 30492, 24886, 21210, 17270, 13056, 11156, 9082, 7534, 6958, 6060, 4632, 3760, 2500, 2342, 1816, 1726, 1768, 1102, 974, 670, 474, 446, 286, 246, 242, 156, 176, 90, 66, 66, 38, 28, 24, 14, 10, 2, 6, 0, 2, 0, 0, 0, 0, 0, 0, 0],
# #     'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1'],
# #     'radius_dic': [0.46, 0.77, 0.77, 0.77, 0.77],
# #     'max_n_nodes': 29,
# #     'with_h': True}
# #     # 'bond1_radius': {'H': 31, 'C': 76, 'N': 71, 'O': 66, 'F': 57},
# #     # 'bond1_stdv': {'H': 5, 'C': 2, 'N': 2, 'O': 2, 'F': 3},
# #     # 'bond2_radius': {'H': -1000, 'C': 67, 'N': 60, 'O': 57, 'F': 59},
# #     # 'bond3_radius': {'H': -1000, 'C': 60, 'N': 54, 'O': 53, 'F': 53}}


# # geom_with_h = {
# #     'name': 'geom',
# #     'atom_encoder': {'H': 0, 'B': 1, 'C': 2, 'N': 3, 'O': 4, 'F': 5, 'Al': 6, 'Si': 7,
# #     'P': 8, 'S': 9, 'Cl': 10, 'As': 11, 'Br': 12, 'I': 13, 'Hg': 14, 'Bi': 15},
# #     'atomic_nb': [1,  5,  6,  7,  8,  9, 13, 14, 15, 16, 17, 33, 35, 53, 80, 83],
# #     'atom_decoder': ['H', 'B', 'C', 'N', 'O', 'F', 'Al', 'Si', 'P', 'S', 'Cl', 'As', 'Br', 'I', 'Hg', 'Bi'],
# #     'max_n_nodes': 181,
# #     'n_nodes': {3: 1, 4: 3, 5: 9, 6: 2, 7: 8, 8: 23, 9: 23, 10: 50, 11: 109, 12: 168, 13: 280, 14: 402, 15: 583, 16: 597,
# #                 17: 949, 18: 1284, 19: 1862, 20: 2674, 21: 3599, 22: 6109, 23: 8693, 24: 13604, 25: 17419, 26: 25672,
# #                 27: 31647, 28: 43809, 29: 56697, 30: 70400, 31: 82655, 32: 104100, 33: 122776, 34: 140834, 35: 164888,
# #                 36: 185451, 37: 194541, 38: 218549, 39: 231232, 40: 243300, 41: 253349, 42: 268341, 43: 272081,
# #                 44: 276917, 45: 276839, 46: 274747, 47: 272126, 48: 262709, 49: 250157, 50: 244781, 51: 228898,
# #                 52: 215338, 53: 203728, 54: 191697, 55: 180518, 56: 163843, 57: 152055, 58: 136536, 59: 120393,
# #                 60: 107292, 61: 94635, 62: 83179, 63: 68384, 64: 61517, 65: 48867, 66: 37685, 67: 32859, 68: 27367,
# #                 69: 20981, 70: 18699, 71: 14791, 72: 11921, 73: 9933, 74: 9037, 75: 6538, 76: 6374, 77: 4036, 78: 4189,
# #                 79: 3842, 80: 3277, 81: 2925, 82: 1843, 83: 2060, 84: 1394, 85: 1514, 86: 1357, 87: 1346, 88: 999,
# #                 89: 300, 90: 390, 91: 510, 92: 510, 93: 240, 94: 721, 95: 360, 96: 360, 97: 390, 98: 330, 99: 540,
# #                 100: 258, 101: 210, 102: 60, 103: 180, 104: 206, 105: 60, 106: 390, 107: 180, 108: 180, 109: 150,
# #                 110: 120, 111: 360, 112: 120, 113: 210, 114: 60, 115: 30, 116: 210, 117: 270, 118: 450, 119: 240,
# #                 120: 228, 121: 120, 122: 30, 123: 420, 124: 240, 125: 210, 126: 158, 127: 180, 128: 60, 129: 30,
# #                 130: 120, 131: 30, 132: 120, 133: 60, 134: 240, 135: 169, 136: 240, 137: 30, 138: 270, 139: 180,
# #                 140: 270, 141: 150, 142: 60, 143: 60, 144: 240, 145: 180, 146: 150, 147: 150, 148: 90, 149: 90,
# #                 151: 30, 152: 60, 155: 90, 159: 30, 160: 60, 165: 30, 171: 30, 175: 30, 176: 60, 181: 30},
# #     'atom_types':{0: 143905848, 1: 290, 2: 129988623, 3: 20266722, 4: 21669359, 5: 1481844, 6: 1,
# #                   7: 250, 8: 36290, 9: 3999872, 10: 1224394, 11: 4, 12: 298702, 13: 5377, 14: 13, 15: 34},
# #     'colors_dic': ['#FFFFFF99',
# #                    'C2', 'C7', 'C0', 'C3', 'C1', 'C5',
# #                    'C6', 'C4', 'C8', 'C9', 'C10',
# #                    'C11', 'C12', 'C13', 'C14'],
# #     'radius_dic': [0.3, 0.6, 0.6, 0.6, 0.6,
# #                    0.6, 0.6, 0.6, 0.6, 0.6,
# #                    0.6, 0.6, 0.6, 0.6, 0.6,
# #                    0.6],
# #     'with_h': True}


# # geom_no_h = {
# #     'name': 'geom',
# #     'atom_encoder': {'B': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Al': 5, 'Si': 6, 'P': 7, 'S': 8, 'Cl': 9, 'As': 10,
# #                      'Br': 11, 'I': 12, 'Hg': 13, 'Bi': 14},
# #     'atomic_nb': [5,  6,  7,  8,  9, 13, 14, 15, 16, 17, 33, 35, 53, 80, 83],
# #     'atom_decoder': ['B', 'C', 'N', 'O', 'F', 'Al', 'Si', 'P', 'S', 'Cl', 'As', 'Br', 'I', 'Hg', 'Bi'],
# #     'max_n_nodes': 91,
# #     'n_nodes': {1: 3, 2: 5, 3: 8, 4: 89, 5: 166, 6: 370, 7: 613, 8: 1214, 9: 1680, 10: 3315, 11: 5115, 12: 9873,
# #                 13: 15422, 14: 28088, 15: 50643, 16: 82299, 17: 124341, 18: 178417, 19: 240446, 20: 308209, 21: 372900,
# #                 22: 429257, 23: 477423, 24: 508377, 25: 522385, 26: 522000, 27: 507882, 28: 476702, 29: 426308,
# #                 30: 375819, 31: 310124, 32: 255179, 33: 204441, 34: 149383, 35: 109343, 36: 71701, 37: 44050,
# #                 38: 31437, 39: 20242, 40: 14971, 41: 10078, 42: 8049, 43: 4476, 44: 3130, 45: 1736, 46: 2030,
# #                 47: 1110, 48: 840, 49: 750, 50: 540, 51: 810, 52: 591, 53: 453, 54: 540, 55: 720, 56: 300, 57: 360,
# #                 58: 714, 59: 390, 60: 519, 61: 210, 62: 449, 63: 210, 64: 289, 65: 589, 66: 227, 67: 180, 68: 330,
# #                 69: 330, 70: 150, 71: 60, 72: 210, 73: 60, 74: 180, 75: 120, 76: 30, 77: 150, 78: 30, 79: 60, 82: 60,
# #                 85: 60, 86: 6, 87: 60, 90: 60, 91: 30},
# #     'atom_types': {0: 290, 1: 129988623, 2: 20266722, 3: 21669359, 4: 1481844, 5: 1, 6: 250, 7: 36290, 8: 3999872,
# #                    9: 1224394, 10: 4, 11: 298702, 12: 5377, 13: 13, 14: 34},
# #     'colors_dic': ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14'],
# #     'radius_dic': [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
# #     'with_h': False}


# # def get_dataset_info(dataset_name, remove_h):
# #     if dataset_name == 'qm9':
# #         if not remove_h:
# #             return qm9_with_h
# #         else:
# #             return qm9_without_h
# #     elif dataset_name == 'geom':
# #         if not remove_h:
# #             return geom_with_h
# #         else:
# #             raise Exception('Missing config for %s without hydrogens' % dataset_name)
# #     elif dataset_name == 'qm9_second_half':
# #         if not remove_h:
# #             return qm9_second_half
# #         else:
# #             raise Exception('Missing config for %s without hydrogens' % dataset_name)
# #     else:
# #         raise Exception("Wrong dataset %s" % dataset_name)
