import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data

class MolecularDataset(InMemoryDataset):
    """In-memory graph dataset built from per-split ``<mode>_crcg.npy`` files.

    Each raw file is loaded with ``np.load(..., allow_pickle=True).item()`` and
    is therefore expected to hold a pickled mapping with the keys
    ``'node_features'`` (or ``'pos'`` as a fallback), ``'edge_index'`` and
    ``'label'``, each being a sequence with one entry per graph.

    Only the split selected by ``mode`` is processed and loaded by an
    instance; the other splits are handled when an instance is created for
    them.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        mode: One of the names in ``splits``.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Applied to every graph in ``process`` before saving.
        pre_filter: Graphs for which this returns a falsy value are dropped
            in ``process``.

    Raises:
        ValueError: If ``mode`` is not one of ``splits``.
    """

    splits = ['train', 'val', 'test']

    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if mode not in self.splits:
            raise ValueError(
                'mode must be one of {}, got {!r}'.format(self.splits, mode)
            )
        # ``mode`` must be set before the base constructor runs, because the
        # base constructor may invoke ``process``, which reads ``self.mode``.
        self.mode = mode
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which may reject pickled ``Data`` objects --
        # confirm against the PyTorch version this project pins.
        self.data, self.slices = torch.load(self._processed_path())

    @property
    def raw_file_names(self):
        """File names expected inside ``self.raw_dir``, one per split."""
        return ['train_crcg.npy', 'val_crcg.npy', 'test_crcg.npy']

    @property
    def processed_file_names(self):
        """File names written inside the processed directory, one per split."""
        return ['Molecular_train.pt', 'Molecular_val.pt', 'Molecular_test.pt']

    def download(self):
        """No-op: the raw files must already be present in ``self.raw_dir``."""
        pass

    def process(self):
        """Convert the raw file of the current split into a collated ``.pt`` file.

        Raises:
            KeyError: If the raw mapping has neither ``'node_features'`` nor
                ``'pos'``, or lacks ``'edge_index'`` / ``'label'``.
            ValueError: If an edge array cannot be interpreted as ``(2, E)``.
        """
        raw_pos = self.raw_file_names.index('{}_crcg.npy'.format(self.mode))
        raw_path = osp.join(self.raw_dir, self.raw_file_names[raw_pos])
        # SECURITY: ``allow_pickle=True`` unpickles arbitrary objects; only use
        # raw files that come from a trusted source.
        np_data = np.load(raw_path, allow_pickle=True).item()

        node_features = np_data.get('node_features', np_data.get('pos'))
        if node_features is None:
            # Fail with the same exception type as the other key lookups
            # instead of an opaque ``TypeError`` from ``zip`` below.
            raise KeyError(
                "raw data must contain 'node_features' or 'pos': {}".format(raw_path)
            )
        edge_indices = np_data['edge_index']
        labels = np_data['label']

        data_list = []
        samples = zip(node_features, edge_indices, labels)
        for graph_idx, (feats, edges, label) in enumerate(samples):
            data = self._build_graph(graph_idx, feats, edges, label)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self._processed_path())

    def _processed_path(self):
        """Return the processed-file path belonging to ``self.mode``."""
        pos = self.processed_file_names.index('Molecular_{}.pt'.format(self.mode))
        return self.processed_paths[pos]

    @staticmethod
    def _to_edge_index(edges):
        """Return ``edges`` as a contiguous ``long`` tensor of shape ``(2, E)``."""
        if isinstance(edges, torch.Tensor):
            # Copy so the source tensor is never aliased, and force ``long``.
            edges = edges.detach().clone().to(torch.long)
        else:
            edges = torch.tensor(edges, dtype=torch.long)

        if edges.numel() == 0:
            # A graph without edges: normalise to the canonical empty layout so
            # the ``size(1)`` lookups downstream work.
            return edges.new_empty((2, 0))

        # Accept the ``(E, 2)`` layout by transposing it.
        if edges.ndim == 2 and edges.size(0) != 2:
            edges = edges.t()
        if edges.ndim != 2 or edges.size(0) != 2:
            raise ValueError(
                'edge_index must have shape (2, E) or (E, 2), got {}'.format(
                    tuple(edges.shape)
                )
            )
        return edges.contiguous()

    def _build_graph(self, graph_idx, feats, edges, label):
        """Build the ``Data`` object for one raw sample."""
        feats = torch.tensor(feats, dtype=torch.float)
        edges = self._to_edge_index(edges)

        # Node count: large enough for every index referenced by an edge and
        # for every feature row, so isolated nodes are still covered.
        # Assumes dim 0 of ``feats`` indexes nodes -- TODO confirm.
        num_nodes = feats.size(0) if feats.dim() > 0 else 0
        if edges.numel() > 0:
            num_nodes = max(num_nodes, int(edges.max()) + 1)

        y = torch.tensor(label, dtype=torch.long).unsqueeze(dim=0)
        edge_attr = torch.ones(edges.size(1), 1)
        # A tensor (not a NumPy array) so it is collated and moved across
        # devices like the other attributes.
        z = torch.zeros(num_nodes, dtype=torch.long)

        data = Data(
            x=feats,
            y=y,
            z=z,
            edge_index=edges,
            edge_attr=edge_attr,
            name=f'Molecular-{self.mode}-{graph_idx}',
            idx=graph_idx,
        )
        # One all-zero entry per edge.
        data.edge_gt_att = torch.zeros(edges.size(1), dtype=torch.long)
        return data
