import torch
import re
from rdkit import Chem, RDLogger
from rdkit.Geometry import Point3D

import torch.nn.functional as F

# Only let RDKit log messages of CRITICAL level through.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Accepted valencies per atom symbol, consulted by check_stability.
# An entry is one of:
#   * an int  -> the only accepted valency,
#   * a list  -> any listed valency is accepted,
#   * a dict  -> keyed by formal charge, each value again an int or a list;
#                a charge that is not a key falls back to the entry for charge 0.
allowed_bonds = {'H': {0: 1, 1: 0, -1: 0},
                 'C': {0: [3, 4], 1: 3, -1: 3},
                 'N': {0: [2, 3], 1: [2, 3, 4], -1: 2},    # In QM9, N+ seems to be present in the form NH+ and NH2+
                 'O': {0: 2, 1: 3, -1: 1},
                 'F': {0: 1, -1: 0},
                 'B': 3, 'Al': 3, 'Si': 4,
                 'P': {0: [3, 5], 1: 4},
                 'S': {0: [2, 6], 1: [2, 3], 2: 4, 3: 5, -1: 3},
                 'Cl': 1, 'As': 3,
                 'Br': {0: 1, 1: 2}, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]}
# Lookup from the integer edge type stored in an adjacency tensor to the RDKit
# bond type (index 0 means "no bond"); indexed by Molecule.build_molecule.
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
             Chem.rdchem.BondType.AROMATIC]
# Reference valency keyed by atomic number. Molecule.build_molecule compares the
# valence reported by check_valency against it under the 'partial' charges policy.
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}


class Molecule:
    """A molecular graph given as tensors, together with its RDKit molecule.

    The constructor stores the raw tensors and immediately calls
    ``build_molecule``; the result is exposed as ``rdkit_mol``.
    """

    def __init__(self, node_types, edge_types, positions, charges, atom_decoder,
                 use_charges, use_3d, charges_policy='no'):
        """ node_types: n      LongTensor (indices into atom_decoder; -1 entries are skipped)
            charges: n         LongTensor (only read when use_charges is True)
            edge_types: n x n  LongTensor (indices into bond_dict; -1 entries count as "no bond")
            positions: n x 3   FloatTensor (only read when use_3d is True)
            atom_decoder: extracted from dataset_infos.
            use_charges: take formal charges from `charges`; otherwise they are zeroed.
            use_3d: attach `positions` to the RDKit molecule as a conformer.
            charges_policy: 'dictionary' reads a trailing "+1"/"-1" from the decoder
                symbol (only when use_charges is False); 'partial' assigns +1 to
                N/O/S atoms whose valence exceeds ATOM_VALENCY by one after each
                bond is added; any other value does nothing extra. """
        assert node_types.dim() == 1 and node_types.dtype == torch.long, f"shape of atoms {node_types.shape} " \
                                                                         f"and dtype {node_types.dtype}"
        assert edge_types.dim() == 2 and edge_types.dtype == torch.long, f"shape of bonds {edge_types.shape} --" \
                                                                         f" {edge_types.dtype}"

        self.use_charges = use_charges
        self.use_3d = use_3d

        if self.use_3d:
            assert len(positions.shape) == 2

        self.charges_policy = charges_policy
        self.node_types = node_types.long()
        self.edge_types = edge_types.long()
        self.positions = positions
        self.charges = charges
        self.rdkit_mol = self.build_molecule(atom_decoder)
        self.num_nodes = len(node_types)
        self.num_node_types = len(atom_decoder)

    def build_molecule(self, atom_decoder, verbose=False):
        """ Build an RDKit molecule from the stored node/edge tensors.

        Nodes of type -1 are left out. Because RDKit numbers atoms in insertion
        order, a graph-index -> RDKit-index map is kept so that bonds and 3D
        coordinates stay attached to the right atoms even when a skipped node
        precedes a kept one; bonds touching a skipped node are dropped.

        When use_charges is False, self.charges is replaced by a zero tensor.

        Returns the RDKit molecule (with one conformer if use_3d is True), or
        None if Chem.KekulizeException is raised while finalising it.
        """
        if verbose:
            print("building new molecule")

        mol = Chem.RWMol()
        if not self.use_charges:
            self.charges = torch.zeros_like(self.node_types)

        # graph node index -> atom index inside `mol`
        rdkit_index = {}
        for graph_idx, (atom, charge) in enumerate(zip(self.node_types, self.charges)):
            atom = int(atom)
            if atom == -1:
                continue
            atom_symbol = atom_decoder[atom]

            formal_charge = 0
            if self.use_charges:
                formal_charge = int(charge)
            elif self.charges_policy == 'dictionary':
                # if the atom is not neutrally charged:
                if "-1" in atom_symbol or "+1" in atom_symbol:
                    # TODO: this may require a more robust handling if we
                    # will add more infos in the atom_decoder strings
                    formal_charge = int(atom_symbol[-2:])   # the formal charge is the last two characters
                    atom_symbol = atom_symbol[:-2]          # what precedes them is the actual atomic symbol
                                                            # (must be done after reading formal_charge)

            a = Chem.Atom(atom_symbol)
            a.SetFormalCharge(formal_charge)
            rdkit_index[graph_idx] = len(rdkit_index)
            mol.AddAtom(a)
            if verbose:
                print("Atom added: ", atom, atom_decoder[atom])

        # torch.triu returns a new tensor, so self.edge_types is not modified below.
        edge_types = torch.triu(self.edge_types, diagonal=1)
        edge_types[edge_types == -1] = 0
        for src, dst in torch.nonzero(edge_types).tolist():
            if src not in rdkit_index or dst not in rdkit_index:
                # bond towards a node that was skipped above
                continue
            begin, end = rdkit_index[src], rdkit_index[dst]
            bond_type = edge_types[src, dst].item()
            mol.AddBond(begin, end, bond_dict[bond_type])
            if verbose:
                print("bond added:", begin, end, bond_type, bond_dict[bond_type])

            # NOTE: tested ONLY with QM9!
            if self.charges_policy != 'partial':
                continue
            # add formal charge to atom: e.g. [O+], [N+], [S+]
            # not support [O-], [N-], [S-], [NH+] etc.
            flag, atomid_valence = check_valency(mol)
            if verbose:
                print("flag, valence", flag, atomid_valence)
            if flag:
                continue
            assert len(atomid_valence) == 2
            idx, v = atomid_valence
            an = mol.GetAtomWithIdx(idx).GetAtomicNum()
            if verbose:
                print("atomic num of atom with a large valence", an)
            if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
                mol.GetAtomWithIdx(idx).SetFormalCharge(1)

        try:
            mol = mol.GetMol()
        except Chem.KekulizeException:
            print("Can't kekulize molecule")
            return None

        if self.use_3d:
            # Set coordinates, reading each atom's row by its original graph index.
            positions = self.positions.double()
            conf = Chem.Conformer(mol.GetNumAtoms())
            for graph_idx, atom_idx in rdkit_index.items():
                x, y, z = positions[graph_idx, :3].tolist()
                conf.SetAtomPosition(atom_idx, Point3D(x, y, z))
            mol.AddConformer(conf)

        return mol


# Functions from GDSS
def check_valency(mol):
    """Run RDKit's property sanitization on `mol` and report valence problems.

    Returns (True, None) when sanitization succeeds. When it raises ValueError,
    returns (False, numbers) where `numbers` are all integers found in the
    error message from its first '#' onwards.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None


def check_stability(molecule, dataset_info, debug=False, atom_decoder=None, smiles=None):
    """ Check every atom's valency against `allowed_bonds`.

    molecule: Molecule object (node_types, edge_types and charges are read;
        none of them is modified).
    dataset_info: provides `atom_decoder` (when the argument is None) and
        `cfg.features.use_charges`.
    debug: print a line for each atom whose valency is rejected.
    atom_decoder: optional override of dataset_info.atom_decoder.
    smiles: optional string printed alongside debug output.

    The valency of an atom is the row sum of its bond orders, truncated to an
    integer, where edge type 4 (aromatic) counts as 1.5 and negative entries
    count as 0. When use_charges is False, every atom is treated as neutral.

    Returns (mol_stable, n_stable_atoms, n_atoms): two 1-element float tensors
    on the device of node_types, and the number of nodes as an int.
    """
    device = molecule.node_types.device
    if atom_decoder is None:
        atom_decoder = dataset_info.atom_decoder

    node_types = molecule.node_types

    # Work on a float copy: writing 1.5 into the LongTensor would truncate it to 1,
    # and writing in place would erase the aromatic / masked markers of the molecule.
    bond_orders = molecule.edge_types.clone().to(torch.float)
    bond_orders[bond_orders == 4] = 1.5
    bond_orders[bond_orders < 0] = 0

    valencies = torch.sum(bond_orders, dim=-1).long()

    use_charges = dataset_info.cfg.features.use_charges
    charges = molecule.charges if use_charges else torch.zeros_like(node_types)

    # NOTE(review): node types of -1 are not skipped here and would index the
    # last atom_decoder entry -- confirm callers never pass masked nodes.
    n_stable_bonds = 0
    mol_stable = True
    for atom_type, valency, charge in zip(node_types, valencies, charges):
        atom_type = int(atom_type)
        valency = int(valency)
        charge = int(charge)
        possible_bonds = allowed_bonds[atom_decoder[atom_type]]
        if isinstance(possible_bonds, dict):
            # unknown charges fall back to the neutral entry
            possible_bonds = possible_bonds[charge] if charge in possible_bonds else possible_bonds[0]
        if isinstance(possible_bonds, int):
            is_stable = possible_bonds == valency
        else:
            is_stable = valency in possible_bonds

        if not is_stable:
            mol_stable = False
            if debug:
                if smiles is not None:
                    print(smiles)
                print(f"Invalid atom {atom_decoder[atom_type]}: valency={valency}, charge={charge}")
                print()
        n_stable_bonds += int(is_stable)

    return torch.tensor([mol_stable], dtype=torch.float, device=device),\
           torch.tensor([n_stable_bonds], dtype=torch.float, device=device),\
           len(node_types)

def make_molecular_list(sampled, chains, batch_size, atom_decoder,
                        keep_chain, n_nodes, use_charges, use_3d, charges_policy):
    """Turn a sampled batch into a list of Molecule objects.

    sampled: object exposing X, charges, E, y and pos.
    chains: optional object exposing X, E (and charges / pos); when given, its
        last frame is overwritten in place with the first `keep_chain` sampled
        graphs, zero-padded along the node dimensions to the chain's node count.
    batch_size: number of graphs to convert.
    n_nodes: per-graph node counts used to crop each graph.
    use_charges / use_3d: whether charges / positions are forwarded to Molecule
        (None is passed otherwise) and copied into `chains`.

    Returns the list of Molecule objects, in batch order.
    """
    nodes, node_charges, edges, _unused_y, coords = (sampled.X, sampled.charges, sampled.E,
                                                     sampled.y, sampled.pos)

    if chains is not None:
        extra = chains.X.size(-1) - nodes.size(-1)
        padded_nodes = F.pad(nodes[:keep_chain], (0, extra, 0, 0))
        padded_edges = F.pad(edges[:keep_chain], (0, extra, 0, extra, 0, 0))
        # Overwrite last frame with the resulting X, E
        chains.X[-1] = padded_nodes
        chains.E[-1] = padded_edges

        if use_charges:
            chains.charges[-1] = F.pad(node_charges[:keep_chain], (0, extra, 0, 0))
        if use_3d:
            padded_coords = F.pad(coords[:keep_chain], (0, 0, 0, extra, 0, 0))
            # the second slice mirrors the original statement exactly
            chains.pos[-1] = padded_coords[:keep_chain]

    def _molecule_at(idx):
        # Crop graph `idx` to its true size and wrap it.
        size = n_nodes[idx]
        graph_nodes = nodes[idx, :size]
        graph_edges = edges[idx, :size, :size]
        graph_charges = node_charges[idx, :size] if use_charges else None
        graph_coords = coords[idx, :size] if use_3d else None
        return Molecule(node_types=graph_nodes, charges=graph_charges,
                        edge_types=graph_edges, positions=graph_coords,
                        atom_decoder=atom_decoder,
                        use_charges=use_charges, use_3d=use_3d,
                        charges_policy=charges_policy)

    #TODO: generalize this part.
    return [_molecule_at(idx) for idx in range(batch_size)]

def mol2smiles(mol):
    """Sanitize `mol` with RDKit and return its SMILES string.

    Returns None when Chem.SanitizeMol raises ValueError.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        smiles = None
    else:
        smiles = Chem.MolToSmiles(mol)
    return smiles