

import numpy as np
import os
import torch
from torch_geometric.data import Data, InMemoryDataset
from itertools import repeat

class MolecularDataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs read from a pickled ``.npy`` dict.

    The ``.npy`` file must hold an object array whose ``.item()`` is a dict
    with one entry per graph under each of these keys:

    * ``'node_features'`` (``'pos'`` is used when it is absent) -- node
      features per graph, converted to ``torch.float``; a 1-D array is
      treated as one scalar feature per node;
    * ``'edge_index'`` -- ``[2, E]``, ``[E, 2]`` or a flat even-length list
      (reshaped row-major to ``[2, E]``), converted to ``torch.long``;
    * ``'label'`` -- scalar or array label, converted to ``torch.long``.

    ``process`` is a no-op: every construction re-parses the ``.npy`` file
    and keeps the collated result in memory.

    Args:
        npy_path: Path to the ``.npy`` file; its directory becomes the
            dataset ``root``.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        # Set before the base constructor runs: the file-name properties
        # below derive from it and the base class may query them.
        self.npy_path = os.path.abspath(npy_path)
        root = os.path.dirname(self.npy_path)
        super().__init__(root=root, transform=transform)

        self.data, self.slices = self.load_npy(self.npy_path)
        self._num_features = self._feature_dim(self.data)

    @property
    def raw_file_names(self):
        """The ``.npy`` file name (relative to ``raw_dir``)."""
        return [os.path.basename(self.npy_path)]

    @property
    def processed_file_names(self):
        """The ``.npy`` file name with its extension swapped for ``.pt``."""
        stem, _ = os.path.splitext(os.path.basename(self.npy_path))
        return [stem + '.pt']

    def process(self):
        """No-op: data is always loaded directly from the ``.npy`` file."""
        pass

    @property
    def num_features(self):
        """Width of the node feature matrix ``x`` (0 when graphs carry no ``x``)."""
        # ``_num_features`` is only assigned after the base constructor
        # returns; load lazily if the property is queried before that.
        if not hasattr(self, '_num_features'):
            self.data, self.slices = self.load_npy(self.npy_path)
            self._num_features = self._feature_dim(self.data)
        return self._num_features

    @staticmethod
    def _feature_dim(data):
        """Return ``data.x.size(1)``, or 0 when ``data.x`` is None."""
        x = data.x
        return 0 if x is None else x.size(1)

    def load_npy(self, npy_path):
        """Parse ``npy_path`` into the collated ``(data, slices)`` storage pair.

        ``slices`` may be ``None`` when exactly one graph is loaded (recent
        PyG ``collate`` returns a lone graph un-collated); ``len`` and ``get``
        handle that case. When no graph is loaded, an empty store with
        zero-length slice boundaries is returned so that ``len()`` is 0.

        Raises:
            FileNotFoundError: If ``npy_path`` does not exist.
            KeyError: If a required key is missing from the stored dict.
            ValueError: On inconsistent lengths or malformed edge indices.
            TypeError: On an unsupported label type.
        """
        np_data = self._read_npy_dict(npy_path)

        node_features_key = 'node_features' if 'node_features' in np_data else 'pos'
        required_keys = (
            (node_features_key, 'Node features'),
            ('edge_index', 'Edge index'),
            ('label', 'Label'),
        )
        for key, description in required_keys:
            if key not in np_data:
                raise KeyError(f"{description} key ('{key}') not found in {npy_path}")

        node_features = np_data[node_features_key]
        edge_indices = np_data['edge_index']
        labels = np_data['label']

        num_graphs = len(labels)
        if not (len(node_features) == num_graphs and len(edge_indices) == num_graphs):
            print(f"Warning: Mismatch in number of entries in {npy_path}. "
                  f"Labels: {num_graphs}, Features: {len(node_features)}, Edges: {len(edge_indices)}")
            # Best effort: keep only the graphs that have all three entries.
            min_len = min(num_graphs, len(node_features), len(edge_indices))
            if min_len == 0:
                raise ValueError(f"Inconsistent data lengths in {npy_path}, cannot create graphs.")
            node_features = node_features[:min_len]
            edge_indices = edge_indices[:min_len]
            labels = labels[:min_len]

        data_list = [
            self._build_graph(i, feats, edges, label, npy_path)
            for i, (feats, edges, label) in enumerate(zip(node_features, edge_indices, labels))
        ]

        if not data_list:
            print(f"Warning: No valid graphs were loaded from {npy_path}.")
            return self._empty_storage()

        return self.collate(data_list)

    @staticmethod
    def _read_npy_dict(npy_path):
        """Load the pickled object stored in ``npy_path`` and return it."""
        # NOTE(security): allow_pickle=True unpickles arbitrary objects and can
        # execute code; only load .npy files from trusted sources.
        try:
            return np.load(npy_path, allow_pickle=True).item()
        except FileNotFoundError:
            print(f"Error: NPY file not found at {npy_path}")
            raise
        except Exception as e:
            print(f"Error loading NPY file {npy_path}: {e}")
            raise

    def _empty_storage(self):
        """Return a ``(data, slices)`` pair describing a dataset of zero graphs."""
        num_feats = getattr(self, '_num_features', 0)
        data = Data(x=torch.empty((0, num_feats), dtype=torch.float),
                    edge_index=torch.empty((2, 0), dtype=torch.long),
                    y=torch.empty((0,), dtype=torch.long))
        # A single boundary per key means "no graphs": len() evaluates to 0.
        slices = {key: torch.zeros(1, dtype=torch.long) for key in ('x', 'edge_index', 'y')}
        return data, slices

    @classmethod
    def _build_graph(cls, i, feats, edges, label, npy_path):
        """Convert the raw entries of graph ``i`` into a validated ``Data`` object."""
        x = cls._as_tensor(feats, torch.float)
        if x.dim() == 1:
            # One scalar feature per node: store as [num_nodes, 1] so the
            # feature width ``x.size(1)`` is defined.
            x = x.reshape(-1, 1)
        edge_index = cls._edges_to_index(cls._as_tensor(edges, torch.long), i, npy_path)
        y = cls._label_to_tensor(label, i, npy_path)
        graph = Data(x=x, edge_index=edge_index, y=y)

        # Edgeless graphs are valid, but min()/max() raise on an empty
        # tensor, so only range-check when there is at least one edge.
        if edge_index.numel() > 0:
            if edge_index.max().item() >= graph.num_nodes:
                raise ValueError(
                    f"Error in graph {i} of {npy_path}: edge_index contains node indices, greater than or equal to the number of nodes ({graph.num_nodes}). Please check the edge indices.")
            if edge_index.min().item() < 0:
                raise ValueError(
                    f"Error in graph {i} of {npy_path}: edge_index contains negative node indices. Please check the edge indices.")
        return graph

    @staticmethod
    def _as_tensor(value, dtype):
        """Return ``value`` as a tensor of ``dtype`` (casting existing tensors)."""
        if isinstance(value, torch.Tensor):
            return value.to(dtype)
        return torch.tensor(value, dtype=dtype)

    @staticmethod
    def _edges_to_index(edges, i, npy_path):
        """Normalize ``edges`` to a contiguous ``[2, num_edges]`` tensor."""
        if edges.ndim == 1:
            if edges.numel() % 2 != 0:
                raise ValueError(f"Edge index for graph {i} in {npy_path} has odd number of elements and cannot be reshaped to [2, num_edges]. Shape: {edges.shape}")
            # Flat layout: first half are sources, second half targets.
            edges = edges.reshape(2, -1)
        elif edges.ndim == 2:
            if edges.size(0) != 2:
                if edges.size(1) != 2:
                    raise ValueError(f"Edge index for graph {i} in {npy_path} must have shape [2, num_edges] or [num_edges, 2]. Got shape: {edges.shape}")
                # [num_edges, 2] edge list -> [2, num_edges].
                edges = edges.t()
        else:
            raise ValueError(f"Edge index for graph {i} in {npy_path} has invalid dimension: {edges.ndim}. Shape: {edges.shape}")
        return edges.contiguous()

    @staticmethod
    def _label_to_tensor(label, i, npy_path):
        """Convert a graph label to a ``torch.long`` tensor (shape ``[1]`` for scalars)."""
        if isinstance(label, (int, float, np.integer, np.floating, np.bool_)):
            # NumPy integer scalars (e.g. elements of an int label array) are
            # not subclasses of ``int``, so they are listed explicitly.
            y = torch.tensor([label], dtype=torch.long)
        elif isinstance(label, (np.ndarray, list, tuple)):
            y = torch.tensor(label, dtype=torch.long)
        elif isinstance(label, torch.Tensor):
            y = label.long()
        else:
            raise TypeError(f"Unsupported label type for graph {i} in {npy_path}: {type(label)}")
        if y.numel() == 1:
            y = y.reshape(1)
        return y

    def get(self, idx):
        """Return graph ``idx`` as a fresh ``Data`` object.

        Each stored tensor is sliced between ``slices[key][idx]`` and
        ``slices[key][idx + 1]`` along the dimension reported by
        ``Data.__cat_dim__``.
        """
        stored = self.data
        # NOTE(review): ``Data.keys`` is callable in recent PyG and a plain
        # attribute in older releases; accept both.
        keys = stored.keys() if callable(stored.keys) else stored.keys
        graph = Data()
        for key in keys:
            item = stored[key]
            if self.slices is None:
                # Single graph stored un-collated: share its tensors through a
                # new container so per-item transforms cannot rebind them.
                graph[key] = item
                continue
            bounds = self.slices[key]
            index = [slice(None)] * item.dim()
            index[stored.__cat_dim__(key, item)] = slice(int(bounds[idx]), int(bounds[idx + 1]))
            graph[key] = item[tuple(index)]
        return graph

    def len(self):
        """Return the number of graphs held by the dataset."""
        if self.slices is None:
            # A single graph is stored without slice boundaries.
            return 1
        for key in ('y', 'x', 'edge_index'):
            if key in self.slices:
                return len(self.slices[key]) - 1
        print("Warning: Could not determine dataset length from slices. Returning 0.")
        return 0












































