import os
from data_process import generate_molecule_nodeset
import torch
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from tqdm import trange, tqdm

# Root directory holding one sub-folder per dataset. Defaults to the original
# hard-coded location; set LLAGA_DATASET_ROOT to run the script on another
# machine without editing the source.
DATASET_ROOT = os.environ.get('LLAGA_DATASET_ROOT', '/data/haotian/LLaGA/dataset')

# Dataset name -> folder that holds its ``processed_data.pt`` input and
# receives the generated split / sample files.
DATASET = {
    name: os.path.join(DATASET_ROOT, name)
    for name in ('tox21', 'qm9', 'zinc12k')
}

transform = RandomLinkSplit(is_undirected=True)

# RandomLinkSplit splits edges at random; seeding torch's RNG right before each
# call is meant to make the saved splits reproducible across re-runs.
SEED = 0


def _load_trusted(path):
    """Load a locally produced pickle (e.g. a PyG ``Data``) via ``torch.load``.

    Newer PyTorch defaults to ``weights_only=True``, which refuses arbitrary
    pickled objects such as ``Data``. PyTorch releases without the
    ``weights_only`` keyword raise ``TypeError`` on it, so fall back to the
    plain call there. Only use on files this project wrote itself:
    unpickling can execute code.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        return torch.load(path)


def _ego_sample(center, neighbours, graph):
    """Build the ``Data`` sample for one link endpoint.

    Args:
        center: node index of the endpoint.
        neighbours: node indices from ``center``'s nodeset.
        graph: split graph providing ``edge_index`` and ``atom_type``.

    Returns:
        ``Data`` whose node 0 is ``center``, followed by ``neighbours`` in
        order; ``x`` is the matching rows of ``graph.atom_type`` and
        ``edge_index`` is ``graph.edge_index`` restricted to these nodes and
        relabelled to their positions.
    """
    # ``subgraph`` relabels nodes by their position in the subset, and ``x``
    # is gathered in the same order. A repeated index (e.g. if a nodeset
    # already contains ``center``) would desynchronise the two, so drop
    # duplicates while keeping first-occurrence order.
    nodes = list(dict.fromkeys(int(n) for n in (center, *neighbours)))
    # ``atom_type`` is indexed by node id, so its length bounds every node id.
    num_nodes = graph.atom_type.size(0)
    edge_index, _ = subgraph(nodes, graph.edge_index, num_nodes=num_nodes, relabel_nodes=True)
    index = torch.tensor(nodes, dtype=torch.long, device=graph.atom_type.device)
    x = graph.atom_type[index]
    return Data(x=x, edge_index=edge_index, num_nodes=x.size(0))


# For each dataset: split the edges of processed_data.pt, save per-split
# nodesets and edge labels, then for every labelled link build one sample per
# endpoint (endpoint + its nodeset) and save them as left_/right_ lists.
for dataset in ['zinc12k']:
    dataset_folder = DATASET[dataset]
    processed_data = _load_trusted(os.path.join(dataset_folder, 'processed_data.pt'))
    torch.manual_seed(SEED)
    train, val, test = transform(processed_data)

    split = {'train': train, 'val': val, 'test': test}
    # Nodesets are computed on each split's own graph and indexed by node id.
    nodesets = {key: generate_molecule_nodeset(graph) for key, graph in split.items()}

    for key, graph in split.items():
        torch.save(nodesets[key], os.path.join(dataset_folder, f'{key}_molecule_nodeset.pt'))
        torch.save(graph.edge_label, os.path.join(dataset_folder, f'{key}_edge_label.pt'))

    for key, graph in split.items():
        nodeset = nodesets[key]
        left_nodes, right_nodes = graph.edge_label_index
        left_samples, right_samples = [], []
        pairs = zip(left_nodes.tolist(), right_nodes.tolist())
        for left, right in tqdm(pairs, total=left_nodes.numel(),
                                desc=f'Processing {dataset} {key}'):
            left_samples.append(_ego_sample(left, nodeset[left], graph))
            right_samples.append(_ego_sample(right, nodeset[right], graph))

        torch.save(left_samples, os.path.join(dataset_folder, f'left_{key}_samples.pt'))
        torch.save(right_samples, os.path.join(dataset_folder, f'right_{key}_samples.pt'))
        
    
    

