# Bond lengths from:
# http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
# And:
# http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf
bonds1 = {'H': {'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92,
                'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134,
                'Cl': 127, 'Br': 141, 'I': 161},
          'C': {'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135,
                'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194,
                'I': 214},
          'N': {'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136,
                'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177},
          'O': {'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142,
                'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164,
                'I': 194},
          'F': {'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142,
                'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156,
                'I': 187},
          'B': {'H':  119, 'Cl': 175},
          'Si': {'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200,
                 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243 },
          'Cl': {'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164,
                 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166,
                 'Br': 214},
          'S': {'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204,
                'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210,
                'I': 234},
          'Br': {'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214,
                 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222},
          'P': {'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203,
                'S': 210, 'F': 156, 'N': 177, 'Br': 222},
          'I': {'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194,
                'S': 234, 'F': 187, 'I': 266},
          'As': {'H': 152}
          }

bonds2 = {'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160},
          'N': {'C': 129, 'N': 125, 'O': 121},
          'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150},
          'P': {'O': 150, 'S': 186},
          'S': {'P': 186}}


bonds3 = {'C': {'C': 120, 'N': 116, 'O': 113},
          'N': {'C': 116, 'N': 110},
          'O': {'C': 113}}


def print_table(bonds_dict):
    letters = ['H', 'C', 'O', 'N', 'P', 'S', 'F', 'Si', 'Cl', 'Br', 'I']

    new_letters = []
    for key in (letters + list(bonds_dict.keys())):
        if key in bonds_dict.keys():
            if key not in new_letters:
                new_letters.append(key)

    letters = new_letters

    for j, y in enumerate(letters):
        if j == 0:
            for x in letters:
                print(f'{x} & ', end='')
            print()
        for i, x in enumerate(letters):
            if i == 0:
                print(f'{y} & ', end='')
            if x in bonds_dict[y]:
                print(f'{bonds_dict[y][x]} & ', end='')
            else:
                print('- & ', end='')
        print()


# print_table(bonds3)


def check_consistency_bond_dictionaries():
    for bonds_dict in [bonds1, bonds2, bonds3]:
        for atom1 in bonds1:
            for atom2 in bonds_dict[atom1]:
                bond = bonds_dict[atom1][atom2]
                try:
                    bond_check = bonds_dict[atom2][atom1]
                except KeyError:
                    raise ValueError('Not in dict ' + str((atom1, atom2)))

                assert bond == bond_check, (
                    f'{bond} != {bond_check} for {atom1}, {atom2}')


stdv = {'H': 5, 'C': 1, 'N': 1, 'O': 2, 'F': 3}
margin1, margin2, margin3 = 10, 5, 3

allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3,
                 'Si': 4, 'P': [3, 5],
                 'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2],
                 'Bi': [3, 5]}


def get_bond_order(atom1, atom2, distance, check_exists=False):
    distance = 100 * distance  # We change the metric

    # Check exists for large molecules where some atom pairs do not have a
    # typical bond length.
    if check_exists:
        if atom1 not in bonds1:
            return 0
        if atom2 not in bonds1[atom1]:
            return 0

    # margin1, margin2 and margin3 have been tuned to maximize the stability of
    # the QM9 true samples.
    if distance < bonds1[atom1][atom2] + margin1:

        # Check if atoms in bonds2 dictionary.
        if atom1 in bonds2 and atom2 in bonds2[atom1]:
            thr_bond2 = bonds2[atom1][atom2] + margin2
            if distance < thr_bond2:
                if atom1 in bonds3 and atom2 in bonds3[atom1]:
                    thr_bond3 = bonds3[atom1][atom2] + margin3
                    if distance < thr_bond3:
                        return 3        # Triple
                return 2            # Double
        return 1                # Single
    return 0                    # No bond


def single_bond_only(threshold, length, margin1=5):
    if length < threshold + margin1:
        return 1
    return 0


def geom_predictor(p, l, margin1=5, limit_bonds_to_one=False):
    """ p: atom pair (couple of str)
        l: bond length (float)"""
    bond_order = get_bond_order(p[0], p[1], l, check_exists=True)

    # If limit_bonds_to_one is enabled, every bond type will return 1.
    if limit_bonds_to_one:
        return 1 if bond_order > 0 else 0
    else:
        return bond_order

    # l = l * 100  # to Angstrom.
    # l = l - 50     # The histograms are shifted by 50
    #
    # if p == ('B', 'C'):
    #     return single_bond_only(115, l)
    # if p == ('B', 'O'):
    #     return single_bond_only(145, l)
    # if p == ('Br', 'Br'):
    #     return single_bond_only(264, l)
    # if p == ('C', 'Bi'):
    #     return single_bond_only(237, l)
    # if p == ('C', 'Br'):
    #     return single_bond_only(149, l)
    # if p == ('C', 'C'):
    #     if l < 75:
    #         return 3
    #     if l < 84.5:
    #         return 2
    #     if l < 93.5:
    #         return 4
    #     if l < 115 + margin1:
    #         return 1
    #     return 0
    # if p == ('C', 'Cl'):
    #     return single_bond_only(165, l)
    # if p == ('C', 'F'):
    #     return single_bond_only(95, l)
    # if p == ('C', 'I'):
    #     return single_bond_only(165, l)
    # if p == ('C', 'N'):
    #     if l < 66.5:
    #         return 3
    #     if l < 77.5:
    #         return 2
    #     if l < 83.5:
    #         return 4
    #     if l < 126 + margin1:
    #         return 1
    #     return 0
    # if p == ('C', 'O'):
    #     if l < 75.5:
    #         return 2
    #     if l < 125 + margin1:
    #         return 1
    #     return 0
    # if p == ('C', 'P'):
    #     if l < 124.5:
    #         return 2
    #     if l < 135 + margin1:
    #         return 1
    #     return 0
    # if p == ('C', 'S'):
    #     if l < 118.5:
    #         return 2
    #     if l < 126.5:
    #         return 4
    #     if l < 170 + margin1:
    #         return 1
    #     return 0
    # if p == ('C', 'Si'):
    #     return single_bond_only(143, l)
    # if p == ('F', 'P'):
    #     return single_bond_only(112, l)
    # if p == ('F', 'S'):
    #     return single_bond_only(115, l)
    # if p == ('H', 'C'):
    #     return single_bond_only(68, l)
    # if p == ('H', 'F'):
    #     return single_bond_only(48, l)
    # if p == ('H', 'N'):
    #     return single_bond_only(68, l)
    # if p == ('H', 'O'):
    #     return single_bond_only(66, l)
    # if p == ('H', 'S'):
    #     return single_bond_only(102, l)
    # if p == ('I', 'I'):
    #     return single_bond_only(267, l)
    # if p == ('N', 'Cl'):
    #     return single_bond_only(122, l)
    # if p == ('N', 'N'):
    #     if l < 65:
    #         return 3
    #     if l < 69.5:
    #         return 1
    #     if l < 72.5:
    #         return 2
    #     if l < 85.5:
    #         return 4
    #     if l < 105 + margin1:
    #         return 1
    #     return 0
    # if p == ('N', 'O'):
    #     if l < 70.5:
    #         return 2
    #     if l < 77:
    #         return 1
    #     if l < 86.5:
    #         return 4
    #     if l < 100 + margin1:
    #         return 1
    #     return 0
    # if p == ('N', 'P'):
    #     if l < 111.5:
    #         return 2
    #     if l < 135 + margin1:
    #         return 1
    #     return 0
    # if p == ('N', 'S'):
    #     if l < 104.5:
    #         return 2
    #     if l < 107.5:
    #         return 1
    #     if l < 110.5:
    #         return 4
    #     if l < 111.5:
    #         return 2
    #     if l < 166 + margin1:
    #         return 1
    #     return 0
    # if p == ('O', 'Bi'):
    #     return single_bond_only(159, l)
    # if p == ('O', 'I'):
    #     return single_bond_only(152, l)
    # if p == ('O', 'O'):
    #     return single_bond_only(93, l)
    # if p == ('O', 'P'):
    #     if l < 102:
    #         return 2
    #     if l < 130 + margin1:
    #         return 1
    #     return 0
    # if p == ('O', 'S'):
    #     if l < 95.5:
    #         return 2
    #     if l < 170 + margin1:
    #         return 1
    #     return 0
    # if p == ('O', 'Si'):
    #     if l < 110.5:
    #         return 2
    #     if l < 115 + margin1:
    #         return 1
    #     return 0
    # if p == ('P', 'S'):
    #     if l < 154:
    #         return 2
    #     if l < 167 + margin1:
    #         return 1
    #     return 0
    # if p == ('S', 'S'):
    #     if l < 153.5:
    #         return 1
    #     if l < 154.5:
    #         return 4
    #     if l < 158.5:
    #         return 1
    #     if l < 162.5:
    #         return 2
    #     if l < 215 + margin1:
    #         return 1
    #     return 0
    # if p == ('Si', 'Si'):
    #     return single_bond_only(249, l)
    # return 0
