import os
from itertools import repeat
from typing import Callable

from scipy.stats import gaussian_kde
import pandas as pd
from tqdm import tqdm
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Mol
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFreeSASA, rdMolDescriptors
import numpy as np
import torch
from ogb.utils.mol import smiles2graph
from torch_geometric.data import InMemoryDataset, Data
from transformers import RobertaTokenizerFast
from rdkit.Avalon import pyAvalonTools
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from mordred import Calculator, descriptors
from torch_geometric.transforms import AddLaplacianEigenvectorPE, AddRandomWalkPE
import pdb
from deepchem.feat.smiles_tokenizer import SmilesTokenizer
from torch.nn.utils.rnn import pad_sequence
import types
import re
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast
from collections import deque, defaultdict
from utils.chem import mol_to_graphs
from transformers import AutoTokenizer, AutoModelForMaskedLM

import random

# Bit length of the Morgan fingerprints built by getmorganfingerprint().
FINGERPRINT_SIZE: int = 1024

def dfs_torch_geometric(edge_index, start_node=0):
    """Return a depth-first visiting order over the nodes of an edge list.

    Edges are treated as undirected. The traversal starts at ``start_node``
    (included even when it has no edges) and explores neighbours in ascending
    id order. Afterwards it restarts from every node that appears in
    ``edge_index`` but was not reached, in ascending id order, so disconnected
    components are included too.

    Args:
        edge_index: Integer array or tensor indexable as ``edge_index[0]`` /
            ``edge_index[1]`` (source and target rows), e.g. a NumPy array or
            torch tensor of shape ``(2, num_edges)``.
        start_node (int): Node the first traversal starts from.

    Returns:
        list[int]: Node ids in visiting order, each appearing once.
    """
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    # Build the undirected adjacency once instead of scanning all edges per node.
    adjacency = defaultdict(list)
    for u, v in zip(sources, targets):
        adjacency[u].append(v)
        adjacency[v].append(u)

    visited = set()
    order = []

    def visit_component(root):
        """Run one DFS from ``root``, extending ``visited`` and ``order``."""
        stack = [root]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            order.append(node)
            # Push larger ids first so the smallest unvisited id is popped next.
            for neighbor in sorted(adjacency[node], reverse=True):
                if neighbor not in visited:
                    stack.append(neighbor)

    visit_component(start_node)

    # Pick up components that are not reachable from start_node.
    for node in sorted(set(sources) | set(targets)):
        if node not in visited:
            visit_component(node)
    return order


def bfs_torch_geometric(edge_index, start_node=0):
    """Return a breadth-first visiting order over the nodes of an edge list.

    Edges are treated as undirected. The traversal starts at ``start_node``
    (included even when it has no edges). A node's neighbours are enqueued in
    this order:

    * the targets of its outgoing edges, in edge order;
    * then the sources of its incoming edges, in edge order.

    Afterwards it restarts from every node that appears in ``edge_index`` but
    was not reached, in ascending id order, so disconnected components are
    included too.

    Args:
        edge_index: Integer array or tensor indexable as ``edge_index[0]`` /
            ``edge_index[1]`` (source and target rows), e.g. a NumPy array or
            torch tensor of shape ``(2, num_edges)``.
        start_node (int): Node the first traversal starts from.

    Returns:
        list[int]: Node ids in visiting order, each appearing once.
    """
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    # Build the adjacency once: all outgoing targets (edge order) first, then
    # all incoming sources (edge order), which fixes the BFS neighbour order.
    adjacency = defaultdict(list)
    for u, v in zip(sources, targets):
        adjacency[u].append(v)
    for u, v in zip(sources, targets):
        adjacency[v].append(u)

    visited = set()
    order = []

    def visit_component(root):
        """Run one BFS from ``root``, extending ``visited`` and ``order``."""
        queue = deque([root])
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            order.append(node)
            queue.extend(n for n in adjacency[node] if n not in visited)

    visit_component(start_node)

    # Pick up components that are not reachable from start_node.
    for node in sorted(set(sources) | set(targets)):
        if node not in visited:
            visit_component(node)
    return order

atoms = ['Al', 'As', 'B', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'K', 'Li', 'N', 'Na', 'O', 'P', 'S', 'Se', 'Si', 'Te']
special = ['(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',
           '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 'p', 's', '/', '\\']
tokens_list = atoms + special
char_to_index = {char: idx for idx, char in enumerate(tokens_list)}

# Longest tokens first, so multi-character symbols ('Cl', 'Br', 'se') win over
# their one-character prefixes. The capture group makes re.split keep the
# matched tokens. Compiled once at import time from char_to_index; later
# mutations of char_to_index are not reflected here.
_SMILES_TOKEN_RE = re.compile(
    '(' + '|'.join(re.escape(tok) for tok in sorted(char_to_index, key=len, reverse=True)) + ')'
)


def encode_smiles(smiles, unknown=None):
    """Encode a SMILES string as a list of vocabulary indices.

    The string is split into the tokens of ``tokens_list``, trying the
    longest token first at each position. Any run of characters that matches
    no token (e.g. ``'@@'``) becomes one token of its own and is mapped to
    ``unknown``.

    Args:
        smiles (str): SMILES string to encode.
        unknown: Value emitted for out-of-vocabulary tokens. Defaults to
            ``None``, matching the previous behaviour.

    Returns:
        list: One entry per token — its index in ``tokens_list``, or ``unknown``.
    """
    tokens = [tok for tok in _SMILES_TOKEN_RE.split(smiles) if tok]
    return [char_to_index.get(tok, unknown) for tok in tokens]

def getmorganfingerprint(mol: Mol):
    """Get the Morgan (circular, radius 2) fingerprint of a molecule.

    The fingerprint is computed once and returned in two forms.

    Args:
        mol (Mol): The molecule.

    Returns:
        tuple: ``(bits, fp)``, where
            * ``bits`` is ``list(fp)``, i.e. the ``FINGERPRINT_SIZE`` bit values;
            * ``fp`` is the RDKit bit vector returned by
              ``AllChem.GetMorganFingerprintAsBitVect``.
    """
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=FINGERPRINT_SIZE)
    return list(fp), fp


def getmaccsfingerprint(mol: Mol):
    """Return the MACCS keys fingerprint of a molecule as a list of 0/1 ints.

    Args:
        mol (Mol): The molecule.
    """
    bit_string = AllChem.GetMACCSKeysFingerprint(mol).ToBitString()
    return list(map(int, bit_string))

def euclidean_distance(p1, p2):
    """Return the Euclidean (L2) distance between points ``p1`` and ``p2``."""
    delta = p1 - p2
    return np.sqrt(np.sum(delta ** 2))


# Names of the RDKit descriptors used by this module. They are used in two places:
#   * RDkit_descriptors() passes them to MoleculeDescriptors.MolecularDescriptorCalculator;
#   * PygOurDataset.process() reads the raw-CSV columns with these names.
# The list order therefore fixes the column order of the descriptor vectors.
filtered_desc_names = [
        'MaxAbsEStateIndex', 'MaxEStateIndex', 'MinAbsEStateIndex', 'MinEStateIndex', 'qed', 
        'HeavyAtomMolWt', 'NumValenceElectrons', 'NumRadicalElectrons', 'MaxPartialCharge',
        'MinPartialCharge', 'MaxAbsPartialCharge', 'MinAbsPartialCharge', 'FpDensityMorgan1',
        'FpDensityMorgan2', 'FpDensityMorgan3', 'BCUT2D_MWHI', 'BCUT2D_MWLOW', 'BCUT2D_CHGHI',
        'BCUT2D_CHGLO', 'BCUT2D_LOGPHI', 'BCUT2D_LOGPLOW', 'BCUT2D_MRHI', 'BCUT2D_MRLOW',
        'BalabanJ', 'Chi0', 'Chi0n', 'Chi0v', 'Chi1', 'Chi1n', 'Chi1v', 'Chi2n', 'Chi2v',
        'Chi3n', 'Chi3v', 'Chi4n', 'Chi4v', 'HallKierAlpha', 'Kappa1', 'Kappa2', 'Kappa3',
        'LabuteASA', 'PEOE_VSA1', 'PEOE_VSA10', 'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13',
        'PEOE_VSA14', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5', 'PEOE_VSA6',
        'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'SMR_VSA1', 'SMR_VSA10', 'SMR_VSA2',
        'SMR_VSA3', 'SMR_VSA4', 'SMR_VSA5', 'SMR_VSA6', 'SMR_VSA7', 'SMR_VSA8', 'SMR_VSA9',
        'SlogP_VSA1', 'SlogP_VSA10', 'SlogP_VSA11', 'SlogP_VSA12', 'SlogP_VSA2', 'SlogP_VSA3',
        'SlogP_VSA4', 'SlogP_VSA5', 'SlogP_VSA6', 'SlogP_VSA7', 'SlogP_VSA8', 'SlogP_VSA9',
        'TPSA', 'EState_VSA1', 'EState_VSA10', 'EState_VSA11', 'EState_VSA2', 'EState_VSA3',
        'EState_VSA4', 'EState_VSA5', 'EState_VSA6', 'EState_VSA7', 'EState_VSA8', 'EState_VSA9',
        'VSA_EState1', 'VSA_EState10', 'VSA_EState2', 'VSA_EState3', 'VSA_EState4',
        'VSA_EState5', 'VSA_EState6', 'VSA_EState7', 'VSA_EState8', 'VSA_EState9',
        'FractionCSP3', 'HeavyAtomCount', 'NHOHCount', 'NOCount', 'NumAliphaticCarbocycles',
        'NumAliphaticHeterocycles', 'NumAliphaticRings', 'NumAromaticCarbocycles',
        'NumAromaticHeterocycles', 'NumAromaticRings',
        'NumHAcceptors', 'NumHDonors', 'NumHeteroatoms', 'NumRotatableBonds',
        'NumSaturatedCarbocycles', 'NumSaturatedHeterocycles', 'NumSaturatedRings',
        'RingCount', 'MolLogP', 'MolMR',
        'fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN', 'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH',
        'fr_Ar_OH', 'fr_COO', 'fr_COO2', 'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine',
        'fr_NH0', 'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
        'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
        'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl', 'fr_azide',
        'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine', 'fr_bicyclic', 'fr_diazo',
        'fr_dihydropyridine', 'fr_epoxide', 'fr_ester', 'fr_ether', 'fr_furan', 'fr_guanido',
        'fr_halogen', 'fr_hdrzine', 'fr_hdrzone', 'fr_imidazole', 'fr_imide', 'fr_isocyan',
        'fr_isothiocyan', 'fr_ketone', 'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone',
        'fr_methoxy', 'fr_morpholine', 'fr_nitrile', 'fr_nitro', 'fr_nitro_arom',
        'fr_nitro_arom_nonortho', 'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation',
        'fr_phenol', 'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
        'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN', 'fr_sulfide',
        'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole', 'fr_thiazole',
        'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea'
        ]

def RDkit_descriptors(smiles):
    """Compute the ``filtered_desc_names`` RDKit descriptors for each SMILES.

    Args:
        smiles (Iterable[str]): SMILES strings.

    Returns:
        tuple: ``(Mol_descriptors, desc_names)``, where
            * ``Mol_descriptors`` is a list with one descriptor tuple per
              input SMILES;
            * ``desc_names`` gives the descriptor names in the same column order.

    Raises:
        ValueError: If RDKit cannot parse one of the SMILES strings.
    """
    calc = MoleculeDescriptors.MolecularDescriptorCalculator(list(filtered_desc_names))
    desc_names = calc.GetDescriptorNames()

    Mol_descriptors = []
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Fail loudly instead of handing None to the calculator, which
            # would otherwise yield a row of meaningless values.
            raise ValueError(f"RDKit could not parse SMILES: {smi!r}")
        # Descriptors are computed on the molecule as parsed (no Chem.AddHs).
        Mol_descriptors.append(calc.CalcDescriptors(mol))
    return Mol_descriptors, desc_names

class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances onto a fixed bank of Gaussian basis functions.

    ``num_gaussians`` centres are spaced evenly over ``[start, stop]``. Every
    Gaussian has a standard deviation equal to the spacing between adjacent
    centres.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        # exp(coeff * d**2) with coeff = -1 / (2 * spacing**2)
        self.coeff = -0.5 / spacing ** 2
        # Registered as a buffer so it moves with the module (e.g. .to(device)).
        self.register_buffer('offset', centers)

    def forward(self, dist):
        """Map distances (flattened to 1-D) to an ``(N, num_gaussians)`` tensor."""
        deltas = dist.view(-1, 1) - self.offset.unsqueeze(0)
        return (self.coeff * deltas.pow(2)).exp()

def gen_conformers(mol, numConfs=100, maxAttempts=1000, pruneRmsThresh=0.1, useExpTorsionAnglePrefs=True, useBasicKnowledge=True, enforceChirality=True):
    """Embed conformers into ``mol`` with RDKit and return the first conformer id.

    ``AllChem.EmbedMultipleConfs`` is called with the given options and
    ``numThreads=0``. Only the first returned conformer id is passed back.

    Args:
        mol: RDKit molecule to embed conformers into.
        numConfs, maxAttempts, pruneRmsThresh, useExpTorsionAnglePrefs,
        useBasicKnowledge, enforceChirality: Forwarded to
            ``AllChem.EmbedMultipleConfs``.

    Returns:
        int: Id of the first embedded conformer.

    Raises:
        IndexError: If no conformer could be embedded.
    """
    conf_ids = AllChem.EmbedMultipleConfs(
        mol,
        numConfs=numConfs,
        maxAttempts=maxAttempts,
        pruneRmsThresh=pruneRmsThresh,
        useExpTorsionAnglePrefs=useExpTorsionAnglePrefs,
        useBasicKnowledge=useBasicKnowledge,
        enforceChirality=enforceChirality,
        numThreads=0,
    )
    if len(conf_ids) == 0:
        # Same exception type as indexing an empty result, but with a useful message.
        raise IndexError(
            f"EmbedMultipleConfs produced no conformers "
            f"(numConfs={numConfs}, maxAttempts={maxAttempts})"
        )
    return conf_ids[0]

def canonical_smiles(smiles):
    """Return RDKit's canonical SMILES for ``smiles``.

    Args:
        smiles (str): Input SMILES string.

    Returns:
        str: Canonical SMILES as produced by ``Chem.MolToSmiles``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure with None; report it clearly
        # instead of passing None on to MolToSmiles.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol)

def _append_missing_nodes(order, num_nodes):
    """Extend ``order`` in place with every id in ``range(num_nodes)`` it lacks.

    The traversal helpers only return nodes that occur in an edge, plus the
    start node. This pads isolated nodes so every id below ``num_nodes``
    appears in the order.

    Args:
        order (list[int]): Traversal order; extended in place.
        num_nodes (int): Number of nodes the order should cover.

    Returns:
        list[int]: The same ``order`` list.
    """
    if len(order) != num_nodes:
        seen = set(order)
        order.extend(k for k in range(num_nodes) if k not in seen)
    return order


class PygOurDataset(InMemoryDataset):
    """In-memory PyG dataset built from ``raw/<phase>_<dataname>.csv``.

    The ``Drug`` column supplies the SMILES and the ``Y`` column the label.
    The columns named in ``filtered_desc_names`` supply precomputed
    descriptors.
    """

    def __init__(
        self,
        root: str = "dataset",
        phase: str = "train",
        dataname: str = "hiv",
        smiles2graph: Callable = smiles2graph,
        transform=None,
        pre_transform=None,
    ):
        """
        Args:
            root (str, optional): The local position of the dataset. Defaults to "dataset".
            phase (str, optional): The data is train, validation or test set. Defaults to "train".
            dataname (str, optional): The name of the dataset. Defaults to "hiv".
            smiles2graph (Callable, optional): Generate the molecular graph from the SMILES
                string. Defaults to smiles2graph.
            transform, pre_transform: Standard PyG dataset transforms.
        """
        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = os.path.join(root, dataname)
        self.version = 1
        self.dataname = dataname
        self.phase = phase
        self.aug = "none"

        # Everything process() needs is created before super().__init__.
        # PyG runs process() from the base constructor when the processed
        # file is missing.
        self.tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
        self.tokenizer_simple = SmilesTokenizer('utils/vocab.txt')
        self.transform_pe = AddRandomWalkPE(walk_length=5)

        cutoff = 10
        num_gaussians = 50
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        super(PygOurDataset, self).__init__(self.folder, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Return the name of the raw file."""
        return self.phase + "_" + self.dataname + ".csv"

    @property
    def processed_file_names(self):
        """Return the name of the processed file."""
        return self.phase + "_" + self.dataname + ".pt"

    def process(self):
        """Generate the processed file from the raw CSV.

        Only executed while the processed ``.pt`` file does not exist yet.
        Rows with a missing ``Drug`` SMILES are dropped. Molecules whose ETKDG
        conformer embedding fails are skipped with a printed message.

        Raises:
            ValueError: If RDKit cannot parse a SMILES string.
        """
        data_df = pd.read_csv(
            os.path.join(self.raw_dir, self.phase + "_" + self.dataname + ".csv")
        )

        data_df["Drug"] = data_df["Drug"].astype(str)  # make sure every value is a string
        data_df = data_df[~data_df["Drug"].str.lower().isin(["nan"])]  # drop missing SMILES ("nan" after astype)
        data_df = data_df.dropna(subset=["Drug"])
        # Filtering leaves holes in the index. Renumber rows 0..n-1 so every
        # per-row lookup below (SMILES, descriptors, labels, encodings)
        # refers to the same molecule.
        data_df = data_df.reset_index(drop=True)
        smiles_list = data_df["Drug"]

        descriptor_list = data_df[filtered_desc_names]
        homolumogap_list = data_df[["Y"]]

        # Fixed-length ids from the simple SMILES tokenizer: truncate to
        # max_len, right-pad shorter sequences with 0.
        max_len = 100
        truncated_padded_sequences = []
        for smiles in smiles_list:
            seq = self.tokenizer_simple.encode(smiles)[:max_len]
            truncated_padded_sequences.append(seq + [0] * (max_len - len(seq)))

        # ChemBERTa tokenization of the whole column, padded to a common length.
        encodings = self.tokenizer(smiles_list.tolist(), truncation=True, padding=True)

        print("Converting SMILES strings into graphs...")
        data_list = []

        for i in tqdm(range(len(smiles_list))):
            data = Data()
            smiles = smiles_list.iloc[i]

            molgraph = Chem.MolFromSmiles(smiles)
            if molgraph is None:
                raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

            Mol_descriptors = descriptor_list.iloc[i]

            # One seeded 3D conformer per molecule.
            params = AllChem.ETKDGv3()
            params.randomSeed = 42
            if AllChem.EmbedMolecule(molgraph, params) != 0:
                print(f"Skipping {smiles}: ETKDG conformer embedding failed")
                continue
            pos = torch.Tensor(molgraph.GetConformer().GetPositions())

            graph = self.smiles2graph(smiles)
            assert len(graph["edge_feat"]) == graph["edge_index"].shape[1]
            assert len(graph["node_feat"]) == graph["num_nodes"]

            num_atoms = len(graph["node_feat"])
            sorted_order_b = _append_missing_nodes(bfs_torch_geometric(graph["edge_index"]), num_atoms)
            sorted_order_d = _append_missing_nodes(dfs_torch_geometric(graph["edge_index"]), num_atoms)

            # Conformer distance between the endpoints of each edge, expanded
            # onto Gaussian basis functions.
            row, col = graph["edge_index"]
            edge_weight = (pos[row] - pos[col]).norm(dim=-1)
            edge_dis = self.distance_expansion(edge_weight)

            (fgs, clusters, atom_features, bond_list, bond_features, fg_features,
             fg_edge_list, fg_edge_features, atom2fg_list) = mol_to_graphs(smiles)
            atom2fg_list = sorted(atom2fg_list, key=lambda x: x[0])

            # Fresh molecule for fingerprinting (molgraph now carries the embedded conformer).
            rdkit_mol = AllChem.MolFromSmiles(smiles)
            mgf, mgf_fp = getmorganfingerprint(rdkit_mol)
            maccs = getmaccsfingerprint(rdkit_mol)
            avalon = [int(b) for b in pyAvalonTools.GetAvalonFP(rdkit_mol).ToBitString()]

            num_fgs = len(fg_features)
            # NOTE(review): stores the functional-group count (taken before the
            # placeholder padding below) under `__num_nodes__`; confirm this is
            # what downstream consumers expect.
            data.__num_nodes__ = num_fgs
            if len(fg_edge_list) > 0:
                fg_edge_index = np.vstack(fg_edge_list).T
            else:
                fg_edge_index = np.empty((2, 0), dtype=int)

            # Both functional-group orders are padded with functional-group ids
            # (previously the DFS order was padded with atom ids).
            sorted_order_b_fg = _append_missing_nodes(bfs_torch_geometric(fg_edge_index), num_fgs)
            sorted_order_d_fg = _append_missing_nodes(dfs_torch_geometric(fg_edge_index), num_fgs)

            # Molecules without functional groups get one all-zero placeholder group.
            if num_fgs == 0:
                fg_features = [[0] * 12]

            data.edge_index = torch.Tensor(graph["edge_index"]).to(torch.int64)
            data.edge_attr = torch.Tensor(bond_features).to(torch.int64)
            data.edge_dis = torch.Tensor(edge_dis)
            data.x = torch.Tensor(atom_features).to(torch.int64)
            data.y = torch.Tensor(homolumogap_list.iloc[i])
            data.input_ids = torch.Tensor(encodings.input_ids[i])
            data.attention_mask = torch.Tensor(encodings.attention_mask[i])
            data.token_ids = torch.Tensor(truncated_padded_sequences[i])
            data.sorted_order_b = torch.Tensor(sorted_order_b).to(torch.int64)
            data.sorted_order_d = torch.Tensor(sorted_order_d).to(torch.int64)
            data.sorted_order_b_fg = torch.Tensor(sorted_order_b_fg).to(torch.int64)
            data.sorted_order_d_fg = torch.Tensor(sorted_order_d_fg).to(torch.int64)
            data.pos = pos
            data.atom2fg_list = torch.Tensor(atom2fg_list).to(torch.int64)
            data.clusters = torch.Tensor(clusters).to(torch.int64)
            data.fg_x = torch.Tensor(fg_features)
            data.fg_edge = torch.Tensor(fg_edge_list).to(torch.int64)
            data.fg_edge_attr = torch.Tensor(fg_edge_features).to(torch.int64)
            data.mgf = torch.tensor(mgf)
            data.maccs = torch.tensor(maccs)
            data.avalon = torch.tensor(avalon)
            data.desp = torch.tensor(Mol_descriptors)
            # NOTE(review): `similarity_vector` is not defined anywhere visible in
            # this module; unless it is provided elsewhere this line raises
            # NameError — confirm where it should come from.
            data.sv = torch.tensor(similarity_vector)
            data.smiles = smiles
            data_list.append(data)

        # Random-walk positional encodings, then the optional user pre_transform.
        data_list = [self.transform_pe(data) for data in data_list]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        print("Saving...")
        torch.save((data, slices), self.processed_paths[0])

    def get(self, idx: int):
        """Return the idx-th sample, sliced out of the collated storage.

        Args:
            idx (int): Index of the sample.
        """
        data = Data()
        for key in self.data.keys():
            item, slices = self.data[key], self.slices[key]
            if key == 'smiles':
                # SMILES strings are indexed directly, one entry per sample.
                data[key] = item[idx]
            else:
                s = list(repeat(slice(None), item.dim()))
                s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
                # Index with a tuple: one slice per tensor dimension.
                data[key] = item[tuple(s)]
        data.idx = idx
        return data
