import rdkit.Chem as Chem
from rdkit.Chem import AllChem
import networkx as nx
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Union
from collections import defaultdict
from torch_geometric.data import Data
import rdkit.Chem.Draw as Draw
from PIL import Image
from guacamol.utils.chemistry import canonicalize
from tensorboardX import SummaryWriter
import json
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Descriptors import MolWt, MolLogP, NumHDonors, NumHAcceptors, TPSA

# Small constant added to the denominators in get_accuracy_bin so that
# precision / recall / F1 stay finite when a count is zero.
EPSILON = 1e-7

def smiles2mol(smiles: str, sanitize: bool=False) -> Optional[Chem.rdchem.Mol]:
    """Parse a SMILES string into an RDKit molecule.

    Args:
        smiles: SMILES string to parse.
        sanitize: If True, parse with RDKit's default sanitization. If False,
            parse without sanitization and then call ``SanitizeMol`` with
            ``sanitizeOps=0`` (no sanitization flags selected).

    Returns:
        The parsed molecule, or None when RDKit cannot parse ``smiles``. Both
        branches report failure the same way.
    """
    if sanitize:
        return Chem.MolFromSmiles(smiles)
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if mol is None:
        # Match the sanitize=True branch. Otherwise SanitizeMol would be
        # handed None and raise.
        return None
    AllChem.SanitizeMol(mol, sanitizeOps=0)
    return mol

def graph2smiles(fragment_graph: nx.Graph, with_idx: bool=False) -> str:
    """Assemble a fragment graph into a SMILES string with explicit bonds.

    Each node's ``'smarts'`` attribute is turned into an atom via
    ``smarts2atom``. Each edge's ``'bondtype'`` attribute becomes the bond
    between its two endpoints.

    Args:
        fragment_graph: Graph whose nodes carry ``'smarts'`` and whose edges
            carry ``'bondtype'``.
        with_idx: If True, every dummy atom (``'*'``) gets its graph node id
            as its isotope label. Node ids are presumably ints; confirm with
            callers.

    Returns:
        ``Chem.MolToSmiles`` output with ``allBondsExplicit=True``.
    """
    rw_mol = Chem.RWMol()
    atom_index_of = {}
    for node, attrs in fragment_graph.nodes(data=True):
        smarts = attrs['smarts']
        atom_idx = rw_mol.AddAtom(smarts2atom(smarts))
        if with_idx and smarts == '*':
            rw_mol.GetAtomWithIdx(atom_idx).SetIsotope(node)
        atom_index_of[node] = atom_idx

    for begin, end in fragment_graph.edges:
        bond_type = fragment_graph[begin][end]['bondtype']
        rw_mol.AddBond(atom_index_of[begin], atom_index_of[end], bond_type)

    return Chem.MolToSmiles(rw_mol, allBondsExplicit=True)

def networkx2data(G: nx.Graph) -> Tuple[Data, Dict[int, int]]:
    r"""Convert a :obj:`networkx.Graph` into a
    :class:`torch_geometric.data.Data` instance, and return the index mapping.

    Nodes are relabelled to ``0..N-1`` in ``G.nodes()`` iteration order.
    Undirected graphs are converted with ``to_directed()``, so each edge
    appears once per direction. Node and edge ``'label'`` attributes become
    ``x`` and ``edge_attr`` (presumably integer labels; confirm with callers).

    Args:
        G: Input graph.

    Returns:
        ``(data, mapping)``. ``mapping`` maps each original node of ``G`` to
        its row index in ``data.x``.
    """
    num_nodes = G.number_of_nodes()
    mapping = dict(zip(G.nodes(), range(num_nodes)))

    G = nx.relabel_nodes(G, mapping)
    if not nx.is_directed(G):
        G = G.to_directed()

    edges = list(G.edges)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).t() would be 1-D with shape (0,); keep the
        # (2, num_edges) layout even when there are no edges.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    x = torch.tensor([label for _, label in G.nodes(data='label')])

    edge_labels = [[label] for _, _, label in G.edges(data='label')]
    if edge_labels:
        edge_attr = torch.tensor(edge_labels, dtype=torch.long)
    else:
        # Same reason as above: keep the (num_edges, 1) layout.
        edge_attr = torch.empty((0, 1), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data, mapping



def smarts2atom(smarts: str) -> Chem.rdchem.Atom:
    """Return the first atom (index 0) of the molecule parsed from ``smarts``.

    If RDKit cannot parse ``smarts``, ``MolFromSmarts`` gives None and the
    attribute access raises AttributeError.
    """
    query_mol = Chem.MolFromSmarts(smarts)
    first_atom = query_mol.GetAtomWithIdx(0)
    return first_atom


def mol_graph2smiles(graph: nx.Graph, postprocessing: bool=True) -> str:
    """Convert a molecule graph into a canonical SMILES string.

    Nodes are relabelled to consecutive integers first. Each node's
    ``'smarts'`` becomes an atom (via ``smarts2atom``), and each edge's
    ``'bondtype'`` becomes the bond between its endpoints.

    Args:
        graph: Graph with ``'smarts'`` on nodes and ``'bondtype'`` on edges.
        postprocessing: If True, pass the SMILES through ``regen_smiles``
            before returning it.

    Returns:
        The SMILES string, post-processed when requested.
    """
    relabelled = nx.convert_node_labels_to_integers(graph)
    rw_mol = Chem.RWMol()
    atom_index_of = {}
    for node in relabelled.nodes:
        atom = smarts2atom(relabelled.nodes[node]['smarts'])
        atom_index_of[node] = rw_mol.AddAtom(atom)

    for begin, end in relabelled.edges:
        bond_type = relabelled[begin][end]['bondtype']
        rw_mol.AddBond(atom_index_of[begin], atom_index_of[end], bond_type)

    raw_smiles = Chem.MolToSmiles(rw_mol.GetMol())
    if not postprocessing:
        return raw_smiles
    return regen_smiles(raw_smiles)
    
def _demote_orphan_aromatic_bonds(mol: Chem.rdchem.Mol) -> None:
    """In place: set AROMATIC bonds to SINGLE unless both end atoms are flagged aromatic."""
    for bond in mol.GetBonds():
        if bond.GetBondType() == Chem.rdchem.BondType.AROMATIC:
            if not (bond.GetBeginAtom().GetIsAromatic() and bond.GetEndAtom().GetIsAromatic()):
                bond.SetBondType(Chem.rdchem.BondType.SINGLE)


def regen_smiles(smiles: str) -> str:
    """Return a canonical, sanitizable SMILES for ``smiles``, repairing aromaticity if needed.

    If RDKit can read ``smiles`` as-is, its canonical SMILES is returned.
    Otherwise the molecule is parsed without sanitization and repaired:

    1. Clear the aromatic flag on atoms that are not in a ring.
    2. Demote AROMATIC bonds that no longer join two aromatic atoms.
    3. Repeatedly (at most 100 passes) clear aromaticity on the atoms named
       by each ``KekulizeException`` that ``DetectChemistryProblems``
       reports, demote bonds again, and re-parse.
    4. Run full sanitization on the result.

    Returns:
        The canonical SMILES of the repaired molecule. If ``smiles`` cannot
        be parsed or repaired, prints ``"<smiles> not valid"`` and returns
        the placeholder ``"CC"``.
    """
    try:
        return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
    except Exception:
        # Chem.MolFromSmiles returns None when sanitization fails, which makes
        # MolToSmiles raise; fall through to the manual repair below.
        # (Deliberately not a bare except: KeyboardInterrupt must propagate.)
        pass

    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if mol is None:
        # Not parseable even without sanitization: use the same fallback as
        # a failed repair instead of crashing on mol.GetAtoms().
        print(f"{smiles} not valid")
        return "CC"

    for atom in mol.GetAtoms():
        if atom.GetIsAromatic() and not atom.IsInRing():
            atom.SetIsAromatic(False)
    _demote_orphan_aromatic_bonds(mol)

    for _ in range(100):
        kekulize_failed = False
        for problem in Chem.DetectChemistryProblems(mol):
            if problem.GetType() == 'KekulizeException':
                kekulize_failed = True
                for atom_idx in problem.GetAtomIndices():
                    mol.GetAtomWithIdx(atom_idx).SetIsAromatic(False)
                _demote_orphan_aromatic_bonds(mol)
        # Round-trip through SMILES so the next pass sees a freshly parsed molecule.
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol), sanitize=False)
        if mol is None:
            print(f"{smiles} not valid")
            return "CC"
        if not kekulize_failed:
            break

    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol), sanitize=False)
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        print(f"{smiles} not valid")
        return "CC"
    return Chem.MolToSmiles(mol)

def get_conn_list(motif: Chem.rdchem.Mol, use_Isotope: bool=False, symm: bool=False) -> Tuple[List[int], Dict[int, int]]:
    '''
    Collect the connection sites (dummy '*' atoms) of a motif, ordered by canonical rank.

    Ranks come from Chem.CanonicalRankAtoms(motif, includeIsotopes=False, breakTies=False).

    Args:
        motif: molecule whose '*' atoms mark connection sites.
        use_Isotope: if True, identify each site by its isotope number (the
            index stored on the dummy atom). If False, use its atom index
            inside the motif.
        symm: if True, report only the first site of each distinct rank, so
            sites the ranking cannot tell apart are listed once.

    Returns:
        (conn_atoms, ordermap). ordermap maps site id -> rank, with items in
        ascending rank order. conn_atoms lists site ids in that order,
        de-duplicated by rank when symm is True. Both are empty when the
        motif has no '*' atom.
    '''
    ranks = list(Chem.CanonicalRankAtoms(motif, includeIsotopes=False, breakTies=False))
    ordermap = {}
    for atom in motif.GetAtoms():
        if atom.GetSymbol() != '*':
            continue
        site_id = atom.GetIsotope() if use_Isotope else atom.GetIdx()
        ordermap[site_id] = ranks[atom.GetIdx()]

    if not ordermap:
        return [], {}

    ordermap = dict(sorted(ordermap.items(), key=lambda item: item[1]))
    if not symm:
        return list(ordermap), ordermap

    conn_atoms = []
    last_rank = -1
    for site_id, rank in ordermap.items():
        if rank != last_rank:
            conn_atoms.append(site_id)
            last_rank = rank
    return conn_atoms, ordermap


def label_attachment(smiles: str) -> str:
    '''
    Label attachment ('*') atoms with their order as isotope, treating symmetric sites as one.

    Dummy atoms are sorted by canonical rank (CanonicalRankAtoms with
    breakTies=False). The first dummy atom of each distinct rank gets isotope
    1, 2, 3, ... in order of increasing rank. Later dummies sharing an
    already-seen rank are left unchanged.

    Returns the canonical SMILES of the labelled molecule.
    '''
    mol = Chem.MolFromSmiles(smiles)
    ranks = list(Chem.CanonicalRankAtoms(mol, breakTies=False))
    dummies = sorted(
        ((atom.GetIdx(), ranks[atom.GetIdx()]) for atom in mol.GetAtoms() if atom.GetSymbol() == '*'),
        key=lambda pair: pair[1],
    )
    seen_ranks = set()
    for atom_idx, rank in dummies:
        if rank in seen_ranks:
            continue
        seen_ranks.add(rank)
        mol.GetAtomWithIdx(atom_idx).SetIsotope(len(seen_ranks))
    return Chem.MolToSmiles(mol)



def get_rec_acc(smiles: List[str], gen_smiles: List[str]) -> float:
    """Fraction of positions where the generated SMILES equals the reference.

    Args:
        smiles: Reference SMILES strings.
        gen_smiles: Generated SMILES, aligned by position with ``smiles``.
            Extra trailing entries are ignored.

    Returns:
        Exact-match accuracy in ``[0, 1]``. Returns ``0.0`` for an empty
        ``smiles`` list instead of dividing by zero.

    Raises:
        IndexError: if ``gen_smiles`` is shorter than ``smiles``.
    """
    num = len(smiles)
    if num == 0:
        return 0.0
    if len(gen_smiles) < num:
        raise IndexError("gen_smiles is shorter than smiles")
    return sum(ref == gen for ref, gen in zip(smiles, gen_smiles)) / num

def get_accuracy(scores: torch.Tensor, labels: torch.Tensor):
    _, preds = torch.max(scores, dim=-1)
    acc = torch.eq(preds, labels).float()

    number, indices = torch.topk(scores, k=10, dim=-1)
    topk_acc = torch.eq(indices, labels.view(-1,1)).float()
    return torch.sum(acc) / labels.nelement(), torch.sum(topk_acc) / labels.nelement()

def sample_from_distribution(distribution: torch.Tensor, greedy: bool=False, topk: int=0):
    """Pick category indices from non-negative weights along the last dimension.

    Args:
        distribution: Non-negative (not necessarily normalised) weights. It
            is passed to ``torch.multinomial``, so it must be 1-D or 2-D.
        greedy: If True, return the argmax instead of sampling.
        topk: ``0`` samples from the full distribution, ``1`` is the same as
            ``greedy``, and ``k > 1`` samples only among the ``k``
            highest-weighted categories of each row.

    Returns:
        Greedy path: ``torch.argmax`` result with shape
        ``distribution.shape[:-1]``. Sampling paths: ``torch.multinomial``
        result with a trailing dimension of size 1.
        NOTE(review): the two shapes differ; callers presumably handle this.
    """
    if greedy or topk == 1:
        return torch.argmax(distribution, dim=-1)

    # Largest number of non-zero entries in any row. Note that
    # torch.where(cond) returns one index tensor per dimension, so its len()
    # is the tensor's rank, not a count of positive entries.
    max_support = int((distribution > 0).sum(dim=-1).max())
    if topk == 0 or max_support <= topk:
        # Masking to the top-k would keep every non-zero entry anyway.
        return torch.multinomial(distribution, 1)

    _, topk_idx = torch.topk(distribution, topk, dim=-1)
    mask = torch.zeros_like(distribution)
    mask.scatter_(-1, topk_idx, torch.ones_like(distribution))
    return torch.multinomial(distribution * mask, 1)

def get_accuracy_bin(scores: torch.Tensor, labels: torch.Tensor):
    """Accuracy, F1, precision and recall for binary predictions.

    A score ``>= 0`` counts as a positive prediction (``scores`` are
    presumably logits; confirm with callers). ``labels`` presumably hold 0/1
    targets. ``EPSILON`` in every denominator keeps the ratios finite when a
    count is zero.

    Returns:
        ``(accuracy, f1, precision, recall)`` as 0-d tensors.
    """
    preds = (scores >= 0).long()
    negatives = 1 - labels
    missed = 1 - preds

    true_pos = (labels * preds).sum().float()
    false_pos = (negatives * preds).sum().float()
    false_neg = (labels * missed).sum().float()

    precision = true_pos / (true_pos + false_pos + EPSILON)
    recall = true_pos / (true_pos + false_neg + EPSILON)
    f1 = 2 * (precision * recall) / (precision + recall + EPSILON)

    accuracy = (preds == labels).float().sum() / labels.nelement()
    return accuracy, f1, precision, recall

def motif_no_dummy(smiles):
    """Return the motif in ``smiles`` with its dummy ('*') atoms stripped.

    The molecule is parsed without sanitization. If exactly one non-dummy
    atom remains, that atom's SMARTS (``GetSmarts()``) is returned.
    Otherwise the non-dummy atoms are written with ``MolFragmentToSmiles``
    and re-canonicalised through an unsanitized round-trip.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    AllChem.SanitizeMol(mol, sanitizeOps=0)
    kept = tuple(atom.GetIdx() for atom in mol.GetAtoms() if atom.GetSymbol() != '*')
    if len(kept) == 1:
        return mol.GetAtomWithIdx(kept[0]).GetSmarts()
    fragment = Chem.MolFragmentToSmiles(mol, kept)
    reparsed = Chem.MolFromSmiles(fragment, sanitize=False)
    return Chem.MolToSmiles(reparsed)

def fragment_to_idx(fragment_smiles):
    """Return the isotope label of the first dummy ('*') atom in ``fragment_smiles``.

    The SMILES is parsed without sanitization. Returns None when the
    fragment contains no dummy atom.
    """
    atoms = Chem.MolFromSmiles(fragment_smiles, sanitize=False).GetAtoms()
    return next((atom.GetIsotope() for atom in atoms if atom.GetSymbol() == '*'), None)

def get_dummy_atoms(mol_graph: nx.Graph) -> List[int]:
    """Return the nodes of ``mol_graph`` whose ``'smarts'`` attribute is ``'*'``.

    Nodes are returned in ``mol_graph.nodes`` iteration order. The result is
    a list of node ids (presumably ints), not pairs. A node without a
    ``'smarts'`` attribute raises KeyError.
    """
    return [node for node in mol_graph.nodes if mol_graph.nodes[node]['smarts'] == '*']

def det_dummy_atoms_from_smiles(smiles: str) -> List[int]:
    """Return the atom indices of the dummy ('*') atoms in the molecule parsed from ``smiles``.

    Dummy atoms are identified by symbol ``'*'``, the same criterion
    ``get_dummy_atoms`` applies to graph nodes. Without this filter every
    atom index would be returned, contradicting the function's name.
    """
    mol = Chem.MolFromSmiles(smiles)
    return [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetSymbol() == '*']