import os
import glob
import json
import random
import torch
import pickle
import numpy as np
import os.path as osp
from torch_geometric.datasets import MoleculeNet
from torch_geometric.utils import dense_to_sparse
from torch.utils.data import random_split, Subset
from torch_geometric.data import Data, InMemoryDataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem import rdmolops

def smiles_to_pyg_data(smiles, y):
    """Convert a SMILES string and its label into a PyG ``Data`` graph.

    Each atom becomes a node with six features, in this order: atomic
    number, total degree, formal charge, number of radical electrons,
    hybridization (as its integer enum value) and the aromaticity flag.
    Each bond is stored as two directed edges (i -> j and j -> i) that share
    one edge feature, the bond type as a double.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.
        y: Label for the molecule; stored as a length-1 ``long`` tensor.

    Returns:
        ``Data`` with ``x`` of shape ``[num_atoms, 6]``, ``edge_index`` of
        shape ``[2, 2 * num_bonds]``, ``edge_attr`` of shape
        ``[2 * num_bonds, 1]`` and ``y`` of shape ``[1]``.

    Raises:
        ValueError: If the SMILES string cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"无法解析 SMILES: {smiles}")

    node_features = [
        [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetNumRadicalElectrons(),
            atom.GetHybridization().real,  # enum value as a plain number
            atom.GetIsAromatic(),
        ]
        for atom in mol.GetAtoms()
    ]
    if node_features:
        x = torch.tensor(node_features, dtype=torch.float)
    else:
        # Keep the feature dimension even when there are no atoms, so the
        # tensor is still 2-D and can be concatenated with other graphs.
        x = torch.empty((0, 6), dtype=torch.float)

    edge_pairs = []
    edge_features = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_type = bond.GetBondTypeAsDouble()
        # Undirected bond -> one directed edge per direction, same feature.
        edge_pairs.extend(([begin, end], [end, begin]))
        edge_features.extend(([bond_type], [bond_type]))

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
    else:
        # A molecule without bonds (e.g. a single atom) would otherwise give
        # 1-D tensors of shape [0]; build correctly shaped empty tensors so
        # edge_index is always [2, E] and edge_attr always [E, 1].
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([y]).long(),
    )

class ZincFluorTrainDataset(InMemoryDataset):
    """Training split of the fluorescence colour dataset as PyG graphs.

    Files live under ``<root>/<NAME>TRAIN/raw`` and
    ``<root>/<NAME>TRAIN/processed`` where ``<NAME>`` is ``name`` upper-cased.
    Indexing returns a 3-tuple ``(original, view_1, view_2)``; the two views
    are independent applications of ``transform`` to the same graph.

    Args:
        root: Base directory that contains the dataset folder.
        name: Dataset name; ``'train'`` is appended and the result upper-cased.
        transform: Optional callable applied per access to build the views.
        pre_transform: Forwarded to ``InMemoryDataset``; ``process`` in this
            class does not apply it.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        # raw_dir / processed_dir read self.root and self.name, so both are
        # set before the parent constructor runs.
        self.root = root
        name = name+'train'
        self.name = name.upper()
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)
        # Per-class sample counts for labels 0 .. max(y). The label tensor is
        # fetched once and reduced with tensor ops instead of a Python-level
        # max() over every element.
        labels = self.data.y
        num_labels = int(labels.max()) + 1
        self.img_num_list = [int((labels == c).sum()) for c in range(num_labels)]

    def __len__(self):
        """Return the number of graphs, derived from the ``x`` slice table."""
        return len(self.slices['x']) - 1

    def __getitem__(self, idx):
        """Return ``(original, view_1, view_2)`` for the graph at ``idx``.

        Without a transform the same graph object is returned three times.
        With a transform, each view is computed on its own clone: the
        transforms defined in this file modify the object they receive, so
        sharing one object would alias all three entries and apply the
        augmentation twice to the same graph.
        """
        data = self.get(idx)
        if self.transform is None:
            return data, data, data
        view_1 = self.transform(data.clone())
        view_2 = self.transform(data.clone())
        return data, view_1, view_2

    @property
    def raw_dir(self):
        return os.path.join(self.root, self.name, 'raw')

    @property
    def raw_file_names(self):
        # Both CSV names are listed here, although process() only reads the
        # train file.
        return ['fluor_colour_subset_with_color_value_train.csv','fluor_colour_subset_with_color_value_test.csv']

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.name, 'processed')

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Read the train CSV, convert every row to a graph and save it.

        Uses the ``SMILES`` and ``color_value`` columns. A SMILES string that
        cannot be parsed aborts processing with the ``ValueError`` raised by
        ``smiles_to_pyg_data``.
        """
        from tqdm import tqdm

        csv_path = os.path.join(self.raw_dir, 'fluor_colour_subset_with_color_value_train.csv')
        df = pd.read_csv(csv_path)
        smiles_column = df.loc[:, "SMILES"].tolist()
        color_values = df.loc[:, "color_value"].tolist()

        rows = zip(smiles_column, color_values)
        data_list = [
            smiles_to_pyg_data(smiles, color_value)
            for smiles, color_value in tqdm(rows, total=len(smiles_column))
        ]

        torch.save(self.collate(data_list), self.processed_paths[0])


class ZincFluorValDataset(InMemoryDataset):
    """Evaluation split of the fluorescence colour dataset as PyG graphs.

    Files live under ``<root>/<NAME>TEST/raw`` and
    ``<root>/<NAME>TEST/processed`` where ``<NAME>`` is ``name`` upper-cased.
    Indexing returns ``(graph, idx)``.

    Args:
        root: Base directory that contains the dataset folder.
        name: Dataset name; ``'test'`` is appended and the result upper-cased.
        transform: Forwarded to ``InMemoryDataset``. NOTE: ``__getitem__``
            here returns the stored graph directly and does not apply it.
        pre_transform: Forwarded to ``InMemoryDataset``; ``process`` in this
            class does not apply it.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        # The directory properties below depend on these two attributes, so
        # they are assigned before the parent constructor is invoked.
        self.name = (name + 'test').upper()
        self.root = root
        super().__init__(root, transform, pre_transform)
        loaded = torch.load(self.processed_paths[0], weights_only=False)
        self.data, self.slices = loaded

    def __len__(self):
        """Return the number of graphs, derived from the ``x`` slice table."""
        node_slices = self.slices['x']
        return len(node_slices) - 1

    def __getitem__(self, idx):
        """Return the stored graph at ``idx`` together with ``idx`` itself."""
        return self.get(idx), idx

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def raw_file_names(self):
        # Both CSV names are listed here, although process() only reads the
        # test file.
        return [
            'fluor_colour_subset_with_color_value_train.csv',
            'fluor_colour_subset_with_color_value_test.csv',
        ]

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Read the test CSV, convert every row to a graph and save it.

        Uses the ``SMILES`` and ``color_value`` columns of the file.
        """
        from tqdm import tqdm

        table = pd.read_csv(
            osp.join(self.raw_dir, 'fluor_colour_subset_with_color_value_test.csv')
        )
        smiles_column = table["SMILES"].tolist()
        labels = table["color_value"].tolist()

        graphs = []
        for smiles, label in tqdm(zip(smiles_column, labels), total=len(smiles_column)):
            graphs.append(smiles_to_pyg_data(smiles, label))

        torch.save(self.collate(graphs), self.processed_paths[0])



from torch_geometric.transforms import Compose
from torch_geometric.transforms import BaseTransform

class RandomNodeDrop(BaseTransform):
    """Randomly remove nodes, together with every edge touching them.

    Each node is dropped independently with probability ``p``; at least one
    node is always kept. The graph passed in is modified and returned.

    Args:
        p: Probability of dropping each node.
    """

    def __init__(self, p=0.1):
        self.p = p  # per-node drop probability

    def __call__(self, data):
        """Drop nodes from ``data`` and return it.

        ``x``, ``edge_index`` and (when present) ``edge_attr`` are filtered;
        surviving nodes are renumbered to ``0 .. kept - 1``. Other attributes
        of ``data`` are left untouched.
        """
        num_nodes = data.num_nodes
        if not num_nodes:
            # Nothing to drop; also avoids torch.randint(0, 0, ...) below.
            return data

        # Build the masks on the graph's own device so that indexing works
        # for graphs that were already moved off the CPU.
        device = data.edge_index.device

        # True for nodes that survive.
        keep = torch.rand(num_nodes, device=device) >= self.p

        # If every node was dropped, re-enable one at random.
        if not bool(keep.any()):
            keep[torch.randint(0, num_nodes, (1,), device=device)] = True

        data.x = data.x[keep]

        # An edge survives only if both of its endpoints do.
        edge_keep = keep[data.edge_index[0]] & keep[data.edge_index[1]]
        kept_edges = data.edge_index[:, edge_keep]

        # Filter edge features only when the graph actually has them; a bare
        # hasattr() check is not enough because the attribute can exist with
        # the value None.
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None:
            data.edge_attr = edge_attr[edge_keep]

        # Map old node ids to new consecutive ids (-1 marks dropped nodes,
        # which no surviving edge references).
        node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
        node_map[keep] = torch.arange(int(keep.sum()), device=device)
        data.edge_index = node_map[kept_edges]

        return data


class AddNoise:
    """Add Gaussian noise to the node features ``x`` of a graph.

    Args:
        mean: Constant offset added to every feature.
        std: Scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0.0, std=0.1):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        """Return a shallow copy of ``data`` whose ``x`` has noise added.

        The input object is left unchanged, so the same graph can be
        augmented several times to obtain independent views. All attributes
        other than ``x`` are shared with the input.
        """
        import copy

        noisy = copy.copy(data)
        noisy.x = data.x + torch.randn_like(data.x) * self.std + self.mean
        return noisy


def get_ZincFluor(root='/data/zz/LOC_LT/LOS/data',args=None):
    """Build the train and test datasets for the ``zinc_fluor`` data.

    The train dataset is created with a Gaussian-noise augmentation
    (``std=0.1``); the test dataset is created without a transform.

    Args:
        root: Base directory passed to both dataset classes.
        args: Accepted but not used by this function.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.
    """
    dataset_name = 'zinc_fluor'
    augmentation = Compose([AddNoise(std=0.1)])

    train_dataset = ZincFluorTrainDataset(
        root=root,
        name=dataset_name,
        transform=augmentation,
    )
    test_dataset = ZincFluorValDataset(root=root, name=dataset_name)
    return train_dataset, test_dataset

if __name__ == "__main__":
    # Demo script: train a small two-layer GCN graph classifier on an 80/20
    # random split of the training dataset and save its weights.
    from torch_geometric.nn import GCNConv, global_mean_pool
    from torch_geometric.loader import DataLoader
    from torch.utils.data import random_split

    class GCNClassifier(torch.nn.Module):
        """Two GCN layers, mean pooling over each graph, then a linear head."""

        def __init__(self, input_dim, hidden_dim, output_dim, dropout_p=0.5):
            super().__init__()
            self.conv1 = GCNConv(input_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, hidden_dim)
            self.classifier = torch.nn.Linear(hidden_dim, output_dim)
            self.dropout_p = dropout_p

        def forward(self, x, edge_index, batch):
            """Return one row of class logits per graph in the batch."""
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = self.conv2(x, edge_index)
            x = F.relu(x)
            x = global_mean_pool(x, batch)
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = self.classifier(x)
            return x

    def _primary_batch(batch):
        """Return the first element when the loader yields a list or tuple.

        ZincFluorTrainDataset.__getitem__ returns a 3-tuple, so the loader
        collates each position separately and yields a sequence of batches
        instead of a single batch object.
        """
        if isinstance(batch, (list, tuple)):
            return batch[0]
        return batch

    def train(model, train_loader, optimizer, criterion, device):
        """Run one training epoch and return the mean loss per graph."""
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = _primary_batch(batch).to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            # criterion averages over the batch; re-weight by batch size so
            # the epoch mean is per graph.
            total_loss += loss.item() * batch.num_graphs
        return total_loss / len(train_loader.dataset)

    def evaluate(model, loader, criterion, device):
        """Return ``(mean loss per graph, accuracy)`` over ``loader``."""
        model.eval()
        total_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in loader:
                batch = _primary_batch(batch).to(device)
                out = model(batch.x, batch.edge_index, batch.batch)
                loss = criterion(out, batch.y)
                total_loss += loss.item() * batch.num_graphs
                pred = out.argmax(dim=1)
                correct += int((pred == batch.y).sum())
        return total_loss / len(loader.dataset), correct / len(loader.dataset)

    def main():
        """Train for 100 epochs, print metrics per epoch, save the weights."""
        torch.manual_seed(42)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        dataset = ZincFluorTrainDataset(root='/data/zz/LOC_LT/LOS/data', name='zinc_fluor')

        input_dim = dataset.num_node_features
        output_dim = dataset.num_classes

        # 80/20 random split of the training dataset.
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32)

        model = GCNClassifier(input_dim, 128, output_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = torch.nn.CrossEntropyLoss()

        num_epochs = 100
        for epoch in range(1, num_epochs + 1):
            train_loss = train(model, train_loader, optimizer, criterion, device)
            test_loss, test_acc = evaluate(model, test_loader, criterion, device)
            print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

        torch.save(model.state_dict(), 'gcn_model.pth')

    main()




