import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import rdqueries


def get_atom_map_set(smi):
    """
    Collect the atom map numbers used in a SMILES string.

    :param smi: SMILES string.
    :return: atom_map: set of the non-zero atom map numbers (map number 0
             marks an unmapped atom and is excluded).
    """
    mol = Chem.MolFromSmiles(smi)
    mapped = {atom.GetAtomMapNum() for atom in mol.GetAtoms()}
    mapped.discard(0)
    return mapped

def polymerization_is_valid(reactants, num_atoms):
    """
    Check the validity of step growth polymerization.
    Atoms labeled 1,2,...,N should appear in reactant A.
    Atoms labeled N+1,N+2,...,2N should appear in reactant B.

    :param reactants: '.'-separated SMILES string of reactants.
    :param num_atoms: number of (non-H) atoms in the repeat unit (N).
    :return: is_valid (True/False). True only when there are exactly two
             distinct reactant SMILES, one containing every label 1..N and
             the other containing every label N+1..2N.
    """
    set_a = set(range(1, num_atoms + 1))
    set_b = set(range(num_atoms + 1, 2 * num_atoms + 1))

    # identical reactant SMILES are collapsed before counting
    reactants = list(set(reactants.split('.')))

    # Only the two-reactant case is accepted; a single-reactant check is
    # currently disabled.
    if len(reactants) != 2:
        return False

    atom_map_a = get_atom_map_set(reactants[0])
    atom_map_b = get_atom_map_set(reactants[1])

    # the set order is arbitrary, so accept either assignment of A and B
    return ((set_a <= atom_map_a and set_b <= atom_map_b)
            or (set_a <= atom_map_b and set_b <= atom_map_a))

def get_neighbor_atom(atom):
    """
    Get the ONLY neighbor of an atom.

    :param atom: query atom; it must have exactly one neighbor.
    :return: neighbor: neighboring atom.
    :raises AssertionError: if ``atom`` does not have exactly one neighbor.
    """
    neighbors = atom.GetNeighbors()
    # Explicit raise instead of a bare ``assert`` so the check is not stripped
    # under ``python -O``; AssertionError is kept for existing callers.
    if len(neighbors) != 1:
        raise AssertionError(
            f'expected exactly one neighbor, found {len(neighbors)}')

    return neighbors[0]

def increase_H(atom, times=1):
    """Add ``times`` explicit hydrogens to ``atom`` (no-op when ``times`` <= 0)."""
    if times > 0:
        atom.SetNumExplicitHs(atom.GetNumExplicitHs() + times)

def reduce_H(atom, times=1):
    """Remove up to ``times`` explicit hydrogens from ``atom``, never below zero."""
    if times > 0:
        remaining = atom.GetNumExplicitHs() - times
        atom.SetNumExplicitHs(max(0, remaining))

def form_double_units(repeat_unit):
    """
    Form double (repeat) units by concatenating two repeat units.

    Two copies of the repeat unit are combined. The non-[At] atoms of the
    first copy get atom map numbers 1..N and those of the second copy
    N+1..2N. The atoms on the two inner open bonds are joined by a single
    bond. The inner bonds are the second [At] of the first copy and the
    first [At] of the second copy, in atom-index order. The atoms on the
    two outer open bonds each gain one explicit H. All four [At] atoms are
    then removed.

    :param repeat_unit: repeat unit of the target chain polymer. For example,
            c1([At])ccc(cc1)C=C([At]), with the two [At] representing the bonds
            connecting neighboring repeat units.
    :return: double_unit: SMILES of the corresponding double unit.
             double_idxs: markings for the double unit, i.e. the atom map
             numbers of the atoms bonded to the four [At], in atom-index
             order: [outer end of copy 1, inner end of copy 1,
             inner end of copy 2, outer end of copy 2].
    """
    AT = 85  # atomic number of astatine, used here as the open-bond marker
    at_query = rdqueries.AtomNumEqualsQueryAtom(AT)

    # assign atom map numbers 1..N to the first copy
    m1 = Chem.MolFromSmiles(repeat_unit)
    atoms = m1.GetAtoms()
    num_atom = len(atoms) - 2  # every atom except the two [At]
    idx = 1
    for atom in atoms:
        if atom.GetAtomicNum() != AT:
            atom.SetAtomMapNum(idx)
            idx += 1
    assert idx == num_atom + 1

    # ... and N+1..2N to the second copy
    m2 = Chem.MolFromSmiles(repeat_unit)
    for atom in m2.GetAtoms():
        if atom.GetAtomicNum() != AT:
            atom.SetAtomMapNum(idx)
            idx += 1
    assert idx == 2 * num_atom + 1

    # combine two mols; atoms of m1 come first, so the first two [At]
    # belong to copy 1 and the last two to copy 2
    dm = Chem.CombineMols(m1, m2)
    at_idxs = [a.GetIdx() for a in dm.GetAtomsMatchingQuery(at_query)]
    assert len(at_idxs) == 4
    assert at_idxs[3] > at_idxs[2] > at_idxs[1] > at_idxs[0]

    # Chem.RWMol(dm) is a copy, so every atom edited below must be fetched
    # from mw itself; editing atoms of dm would leave mw unchanged.
    mw = Chem.RWMol(dm)

    neighbor_atoms = [get_neighbor_atom(mw.GetAtomWithIdx(i)) for i in at_idxs]
    double_idxs = [neighbor.GetAtomMapNum() for neighbor in neighbor_atoms]

    # outer ends: the open bond is replaced by an H
    increase_H(neighbor_atoms[0])
    increase_H(neighbor_atoms[3])
    # inner ends: the two open bonds become one bond between the copies
    mw.AddBond(neighbor_atoms[1].GetIdx(), neighbor_atoms[2].GetIdx(),
               Chem.BondType.SINGLE)

    # remove the [At] atoms from the highest index down so the remaining
    # indices stay valid
    for i in reversed(at_idxs):
        mw.RemoveAtom(i)

    assert len(mw.GetAtomsMatchingQuery(at_query)) == 0
    double_unit = Chem.MolToSmiles(mw)

    return double_unit, double_idxs

def get_atom_by_atom_map_num(mol, idx):
    """
    Return the first atom of ``mol`` whose atom map number equals ``idx``,
    or None if there is no such atom.

    Map number 0 denotes an unmapped atom, so ``idx`` must be non-zero
    (checked with ``assert``).
    """
    assert idx != 0
    matches = (atom for atom in mol.GetAtoms() if atom.GetAtomMapNum() == idx)
    return next(matches, None)

def modify_H(atom, bond_type=None, charge=0):
    """
    Adjust the explicit hydrogen count of ``atom`` in place for a newly
    formed bond and/or a newly assigned formal charge.

    :param atom: atom to modify.
    :param bond_type: type of the bond being added to ``atom``. A single,
            double or triple bond removes 1, 2 or 3 explicit hydrogens. Any
            other value (including None) leaves the count unchanged.
    :param charge: formal charge being assigned. +1 adds one hydrogen and
            -1 removes one; other values leave the count unchanged.
    """
    if bond_type == Chem.BondType.SINGLE:
        reduce_H(atom)
    elif bond_type == Chem.BondType.DOUBLE:
        reduce_H(atom, times=2)
    elif bond_type == Chem.BondType.TRIPLE:
        # a triple bond consumes three hydrogens
        reduce_H(atom, times=3)

    if charge == 1:
        increase_H(atom)
    elif charge == -1:
        reduce_H(atom)

def remove_atom_map(smi):
    """
    Return ``smi`` re-written as SMILES with every atom map number cleared.

    Best effort: returns None if the SMILES cannot be parsed or any RDKit
    call raises.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)
        return Chem.MolToSmiles(mol)
    except Exception:
        return None

def produce_candidate_monomer(reactants, double_idxs):
    """
    Produce candidate monomer from reactants of the double unit.

    The end group bonded to the atom mapped ``double_idxs[2]`` is moved onto
    the atom mapped ``double_idxs[0]``, keeping the same bond type. The end
    group is the first unmapped neighbour. A non-zero formal charge on the
    ``double_idxs[2]`` atom is also copied to the ``double_idxs[0]`` atom.
    The fragment containing atom map number 1 is then returned with all
    atom map numbers removed.

    :param reactants: atom-mapped SMILES of the reactants of the double unit.
    :param double_idxs: markings for the double unit.
    :return: candidate_monomer SMILES, or None if the reactants cannot be
             parsed, a marked atom is missing, the end group is attached by
             an aromatic bond, the end group is already bonded to the target
             atom, or the fragments fail sanitization.
    """
    m = Chem.MolFromSmiles(reactants)
    if m is None:
        # unparsable reactant SMILES: report "no candidate" like other failures
        return None
    mw = Chem.RWMol(m)

    atom_left = get_atom_by_atom_map_num(mw, double_idxs[0])
    atom_left_2 = get_atom_by_atom_map_num(mw, double_idxs[2])
    if atom_left is None or atom_left_2 is None:
        return None

    charge = atom_left_2.GetFormalCharge()
    if charge != 0:
        atom_left.SetFormalCharge(charge)
        modify_H(atom_left, charge=charge)

    # the end group starts at the first unmapped neighbour of atom_left_2
    left_end_group_start = next(
        (nbr for nbr in atom_left_2.GetNeighbors() if nbr.GetAtomMapNum() == 0),
        None)

    if left_end_group_start is not None:
        bond = mw.GetBondBetweenAtoms(left_end_group_start.GetIdx(),
                                      atom_left_2.GetIdx())
        bond_type = bond.GetBondType()
        if bond_type == Chem.BondType.AROMATIC:
            # can't deal with aromatic bonds
            return None
        if mw.GetBondBetweenAtoms(atom_left.GetIdx(),
                                  left_end_group_start.GetIdx()) is not None:
            # already bonded; RWMol.AddBond refuses duplicate bonds
            return None

        modify_H(atom_left, bond_type)
        mw.AddBond(atom_left.GetIdx(), left_end_group_start.GetIdx(), bond_type)
        mw.RemoveBond(atom_left_2.GetIdx(), left_end_group_start.GetIdx())

    try:
        mols = Chem.rdmolops.GetMolFrags(mw, asMols=True, sanitizeFrags=True)
    except Exception:
        # sanitization of a chemically invalid fragment raises; no candidate
        return None

    for frag in mols:
        if get_atom_by_atom_map_num(frag, 1) is not None:
            return remove_atom_map(Chem.MolToSmiles(frag))

    return None

def produce_double_candidate_monomer(reactants, double_idxs):
    """
    Produce double candidate monomer by concatenating two candidate monomers.

    The atom mapped ``double_idxs[0]`` and the atom mapped ``double_idxs[3]``
    each lose one explicit hydrogen and are joined by a single bond. The
    whole molecule is then written as SMILES with atom map numbers removed.

    :param reactants: atom-mapped SMILES of the reactants of the double unit.
    :param double_idxs: markings for the double unit.
    :return: double_candidate_monomer SMILES, or None if the reactants cannot
             be parsed, a marked atom is missing, the two atoms are already
             bonded, or the result cannot be re-parsed.
    """
    m = Chem.MolFromSmiles(reactants)
    if m is None:
        # unparsable reactant SMILES: report "no candidate" like other failures
        return None
    mw = Chem.RWMol(m)

    atom_left = get_atom_by_atom_map_num(mw, double_idxs[0])
    atom_right_2 = get_atom_by_atom_map_num(mw, double_idxs[3])
    if atom_left is None or atom_right_2 is None:
        return None
    if mw.GetBondBetweenAtoms(atom_left.GetIdx(),
                              atom_right_2.GetIdx()) is not None:
        # already bonded; RWMol.AddBond refuses duplicate bonds
        return None

    reduce_H(atom_left)
    reduce_H(atom_right_2)
    mw.AddBond(atom_left.GetIdx(), atom_right_2.GetIdx(),
               Chem.BondType.SINGLE)
    smi = Chem.MolToSmiles(mw)

    return remove_atom_map(smi)
