from typing import Dict, List, Tuple, Optional, Union
import re

from rdkit import Chem

from torch import Tensor
from src.datatypes.sparse import SparseGraph

from src.datatypes.sparse import to_directed

from rdkit.Chem.rdchem import BondType as BT

# Integer bond order -> RDKit bond type. Only the integral orders 1-3 are
# covered (there is no entry for aromatic bonds).
BOND_TYPES_REAL = {
    1: BT.SINGLE,
    2: BT.DOUBLE,
    3: BT.TRIPLE,
}

# RDKit bond type -> its string form, as produced by ``str(bond_type)``.
BOND_TYPES = {
    bond_type: str(bond_type)
    for bond_type in (BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC)
}

# Inverse of BOND_TYPES: string form -> RDKit bond type. build_molecule uses
# it to turn decoded bond names into bond types.
BOND_TYPES_REV = {
    str(bond_type): bond_type
    for bond_type in (BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC)
}
# Reference valence per element, keyed by atomic number. build_molecule
# compares the valence reported by RDKit against these values when deciding
# whether to add a +1 formal charge in relaxed mode.
ATOM_VALENCY = {
    6: 4,    # C
    7: 3,    # N
    8: 2,    # O
    9: 1,    # F
    15: 3,   # P
    16: 2,   # S
    17: 1,   # Cl
    35: 1,   # Br
    53: 1,   # I
}

def mol2smiles(mol, sanitize=True):
    """Return the canonical SMILES string of an RDKit molecule.

    Args:
        mol: RDKit molecule to serialise.
        sanitize: When true, ``Chem.SanitizeMol`` is run on ``mol`` first
            (this operates on ``mol`` itself, not on a copy).

    Returns:
        The canonical SMILES string, or ``None`` when sanitization was
        requested and raised a ``ValueError``.
    """
    if not sanitize:
        return Chem.MolToSmiles(mol, canonical=True)

    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        # The molecule cannot be sanitized; signal that with None.
        return None

    return Chem.MolToSmiles(mol, canonical=True)


def build_molecule(
        atom_types: Tensor,
        edge_index: Tensor,
        edge_types: Tensor,
        atom_decoder: Union[Dict[int, str], Dict[str, int]],
        bond_decoder: Optional[Dict[int, str]]=None,
        relaxed: bool=False,
        verbose: bool=False
    ):
    """Assemble an RDKit ``RWMol`` from atom and bond class tensors.

    Args:
        atom_types: Iterable of scalar tensors; each ``.item()`` is looked up
            in ``atom_decoder`` to get the element symbol of one atom.
        edge_index: 2-D tensor that is transposed with ``permute(1, 0)`` so
            that each row yields a ``(source, target)`` atom index pair
            (i.e. expected shape ``(2, num_edges)``).
        edge_types: Iterable of scalar tensors, one per edge, looked up in
            ``bond_decoder``. It is zipped with the edges, so the shorter of
            the two determines how many bonds are processed.
        atom_decoder: Mapping index -> symbol, or symbol -> index (the latter
            is reversed by ``check_atom_bond_decoders``).
        bond_decoder: Mapping index -> bond name (or its reverse). The decoded
            names must be keys of ``BOND_TYPES_REV``. Although the parameter
            defaults to ``None`` it is required in practice.
        relaxed: When true, the valence is checked after every added bond and
            a +1 formal charge is set on N, O or S atoms whose valence exceeds
            the reference in ``ATOM_VALENCY`` by exactly one (e.g. [N+], [O+],
            [S+]). Negative charges and cases such as [NH+] are not handled.
        verbose: Print progress information while building.

    Returns:
        The constructed ``Chem.RWMol``. Apart from the property check that
        ``check_valency`` runs in relaxed mode, it is not sanitized.

    Raises:
        TypeError: If ``bond_decoder`` is ``None``.
        KeyError: If an atom or bond class is missing from its decoder.
    """
    if verbose:
        print("building new molecule")

    # A missing bond decoder used to fail inside check_atom_bond_decoders with
    # an opaque "object of type 'NoneType' has no len()"; fail with the same
    # exception type but an actionable message instead.
    if bond_decoder is None:
        raise TypeError(
            "build_molecule() requires a bond_decoder mapping, got None"
        )

    atom_decoder, bond_decoder = check_atom_bond_decoders(
        atom_decoder,
        bond_decoder
    )

    ###############################  PARSE ATOMS  ##############################
    mol = Chem.RWMol()
    for atom in atom_types:
        atom_class = atom.item()
        symbol = atom_decoder[atom_class]
        mol.AddAtom(Chem.Atom(symbol))
        if verbose:
            print("Atom added: ", atom_class, symbol)

    ###############################  PARSE BONDS  ##############################
    for bond, link in zip(edge_types, edge_index.permute(1, 0)):
        src = link[0].item()
        dst = link[1].item()

        # Self-loops have no chemical meaning and are skipped.
        if src == dst:
            continue

        bond_class = bond.item()
        bond_name = bond_decoder[bond_class]
        mol.AddBond(src, dst, BOND_TYPES_REV[bond_name])
        if verbose:
            print("bond added:", src, dst, bond_class, bond_name)

        if not relaxed:
            continue

        # Relaxed mode: try to absorb a valence excess of one with a +1
        # formal charge on the offending atom.
        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        if flag:
            continue

        # check_valency is expected to report [atom index, valence].
        assert len(atomid_valence) == 2
        idx = atomid_valence[0]
        v = atomid_valence[1]
        an = mol.GetAtomWithIdx(idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)

    return mol

class GraphToMoleculeConverter:
    """Callable that converts sparse graphs into RDKit molecules.

    The decoders are normalised once at construction time with
    ``check_atom_bond_decoders``; every call then builds one molecule per
    graph with ``build_molecule`` and optionally repairs it with
    ``correct_mol``.
    """

    def __init__(
            self,
            atom_decoder: Union[Dict[int, str], Dict[str, int]],
            bond_decoder: Union[Dict[int, str], Dict[str, int]],
            relaxed: bool=False,
            post_hoc_mols_fix: bool=False
        ):
        """Store the normalised decoders and the default conversion options.

        Args:
            atom_decoder: Atom mapping in either direction (index -> name or
                name -> index).
            bond_decoder: Bond mapping in either direction.
            relaxed: Default value of ``relaxed`` passed to ``build_molecule``.
            post_hoc_mols_fix: Whether ``correct_mol`` is applied to each
                built molecule by default.
        """
        decoders = check_atom_bond_decoders(atom_decoder, bond_decoder)
        self.atom_decoder, self.bond_decoder = decoders

        self.relaxed = relaxed
        self.post_hoc_mols_fix = post_hoc_mols_fix


    def __call__(
            self,
            batch: Union[List[SparseGraph], SparseGraph],
            override_relaxed: Optional[bool]=None,
            override_post_hoc_mols_fix: Optional[bool]=None
        ) -> Union[List[Chem.Mol], Chem.Mol]:
        """Convert one graph, a batched graph, or a list of graphs.

        Args:
            batch: A list of graphs, a batched graph (recognised by having a
                ``ptr`` attribute, and split with ``to_data_list()``), or a
                single graph.
            override_relaxed: If not ``None``, used instead of
                ``self.relaxed`` for this call.
            override_post_hoc_mols_fix: If not ``None``, used instead of
                ``self.post_hoc_mols_fix`` for this call.

        Returns:
            A list of molecules for list / batched input, or a single
            molecule when a single graph was given.
        """
        # Normalise the input to a list of graphs, remembering whether the
        # caller should get a list back.
        if isinstance(batch, list):
            returns_list = True
            graphs = batch
        elif hasattr(batch, 'ptr'):
            returns_list = True
            graphs = batch.to_data_list()
        else:
            returns_list = False
            graphs = [batch]

        # Per-call overrides take precedence over the instance defaults.
        use_relaxed = self.relaxed if override_relaxed is None else override_relaxed
        if override_post_hoc_mols_fix is None:
            apply_fix = self.post_hoc_mols_fix
        else:
            apply_fix = override_post_hoc_mols_fix

        molecules = []

        graph: SparseGraph
        for graph in graphs:
            # Work on a copy so the caller's graph is left untouched.
            graph = graph.clone()

            if graph.is_undirected():
                graph.edge_index, graph.edge_attr = to_directed(
                    graph.edge_index, graph.edge_attr
                )

            # collapse classes if needed
            graph.collapse()

            mol = build_molecule(
                atom_types=graph.x,
                edge_index=graph.edge_index,
                edge_types=graph.edge_attr,
                atom_decoder=self.atom_decoder,
                bond_decoder=self.bond_decoder,
                relaxed=use_relaxed,
            )

            if apply_fix:
                mol, _ = correct_mol(mol)

            molecules.append(mol)

        if returns_list:
            return molecules
        return molecules[0]


def _as_index_to_name(decoder):
    """Return ``decoder`` keyed by index, reversing it if it is keyed by name."""
    # None and empty mappings carry no direction information: pass through.
    if decoder and isinstance(next(iter(decoder)), str):
        return {index: name for name, index in decoder.items()}
    return decoder


def check_atom_bond_decoders(
        atom_decoder,
        bond_decoder
    ):
    """Normalise atom and bond decoders to the ``index -> name`` direction.

    A decoder whose first key is a string is assumed to map
    ``name -> index`` and is reversed; any other decoder is returned as is
    (the very same object). ``None`` and empty decoders are passed through
    unchanged, so an omitted bond decoder no longer raises here.

    Args:
        atom_decoder: Atom mapping in either direction, or ``None``.
        bond_decoder: Bond mapping in either direction, or ``None``.

    Returns:
        Tuple ``(atom_decoder, bond_decoder)`` in ``index -> name`` form.
    """
    return _as_index_to_name(atom_decoder), _as_index_to_name(bond_decoder)


# Functions from GDSS
def check_valency(mol):
    """Run RDKit's property sanitization step on ``mol`` and report the result.

    Only the ``SANITIZE_PROPERTIES`` operation is executed.

    Returns:
        ``(True, None)`` when the check passes. Otherwise ``(False, numbers)``
        where ``numbers`` is the list of all integers found in the error
        message from its first ``'#'`` onwards; callers in this file expect
        it to be ``[atom index, valence]``. If the message contains no
        ``'#'``, ``str.find`` returns -1 and only the last character of the
        message is scanned.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        numbers = [int(token) for token in re.findall(r'\d+', tail)]
        return False, numbers

    return True, None
    


# code from: https://github.com/GRAPH-0/CDGS/blob/main/utils.py
def correct_mol(
        mol: Chem.Mol,
    ):
    """Lower bond orders in place until ``mol`` passes the valence check.

    While ``check_valency`` reports an offending atom, the highest-order bond
    of that atom is weakened by one step: triple -> double, double -> single,
    aromatic -> single, and a single bond (or any other bond type) is removed.

    Compared with the original CDGS routine this version:
      * handles aromatic bonds (``int(BT.AROMATIC)`` is 12, so the previous
        ``BOND_TYPES_REAL[order - 1]`` lookup raised ``KeyError`` after the
        bond had already been removed);
      * ranks bonds by ``GetBondTypeAsDouble()`` so an aromatic bond (1.5)
        sorts between single and double instead of above triple;
      * stops when the offending atom has no bonds left instead of looping
        forever.

    Args:
        mol: Editable RDKit molecule (bonds are removed / re-added on it).

    Returns:
        Tuple ``(mol, no_correct)`` where ``no_correct`` is ``True`` when the
        molecule already passed the valence check and was left untouched.
    """
    flag, atomid_valence = check_valency(mol)
    no_correct = flag

    while not flag:
        # check_valency is expected to report [atom index, valence].
        assert len(atomid_valence) == 2
        idx = atomid_valence[0]

        bonds = list(mol.GetAtomWithIdx(idx).GetBonds())
        if not bonds:
            # Nothing left to weaken on this atom; further iterations could
            # not make progress.
            break

        # max() returns the first maximal element, matching the stable
        # descending sort used previously.
        strongest = max(bonds, key=lambda b: b.GetBondTypeAsDouble())
        start = strongest.GetBeginAtomIdx()
        end = strongest.GetEndAtomIdx()
        bond_type = strongest.GetBondType()

        if bond_type == BT.AROMATIC:
            weaker = BT.SINGLE
        else:
            # None when there is no lower integral order (e.g. a single bond).
            weaker = BOND_TYPES_REAL.get(int(bond_type) - 1)

        mol.RemoveBond(start, end)
        if weaker is not None:
            mol.AddBond(start, end, weaker)

        flag, atomid_valence = check_valency(mol)

    return mol, no_correct