import os.path as osp
import os
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data

class MolecularDIRDataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs split into train/val/test.

    Each split is read from a raw ``{split}_crcg.npy`` file and collated into
    its own processed file ``Molecular_DIR_{split}.pt``. An instance exposes
    exactly one split, selected by ``mode``.

    Args:
        root: Dataset root directory. Raw ``.npy`` files may live either in
            ``root/raw`` or directly in ``root``.
        mode: Which split to load; one of ``splits``.
        transform, pre_transform, pre_filter: Standard PyG dataset hooks,
            forwarded to :class:`InMemoryDataset`.

    Raises:
        ValueError: If ``mode`` is not one of ``splits``.
        FileNotFoundError: If the processed file for ``mode`` could not be
            produced (e.g. its raw ``.npy`` file is missing).
    """

    splits = ['train', 'val', 'test']

    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if mode not in self.splits:
            raise ValueError(f"mode must be one of {self.splits}, got {mode!r}")
        self.mode = mode

        # The base constructor runs download()/process() when processed files are missing.
        super().__init__(root, transform, pre_transform, pre_filter)

        # processed_paths follows the order of `splits` (see processed_file_names).
        path = self.processed_paths[self.splits.index(mode)]
        if not osp.exists(path):
            raise FileNotFoundError(
                f"Processed file {path} for mode '{mode}' does not exist; "
                f"check that {mode}_crcg.npy is present in {self.raw_dir}."
            )
        self.data, self.slices = self._load_processed(path)

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple that :meth:`process` saved to ``path``.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        refuses to unpickle PyG ``Data`` objects, so opt out explicitly.
        NOTE: this unpickles arbitrary objects; only use it on files this class
        produced itself, never on untrusted files.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the `weights_only` argument.
            return torch.load(path)

    @property
    def raw_dir(self):
        """Return ``root/raw`` if that directory exists, otherwise ``root`` itself."""
        raw = osp.join(self.root, 'raw')
        return raw if osp.isdir(raw) else self.root

    @property
    def processed_dir(self):
        """Return the directory that holds the collated ``.pt`` files."""
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        """Return the raw file names, one per split, in ``splits`` order."""
        return [f'{split}_crcg.npy' for split in self.splits]

    @property
    def processed_file_names(self):
        """Return the processed file names, one per split, in ``splits`` order."""
        return [f'Molecular_DIR_{split}.pt' for split in self.splits]

    def download(self):
        """Do nothing: raw files must be provided manually under ``root``."""
        pass

    def process(self):
        """Convert each split's raw ``.npy`` file into a collated ``.pt`` file.

        A split is skipped with a printed message, and nothing is written for
        it, when its raw file is missing or when no valid graphs remain after
        validation and ``pre_filter``.
        """
        self._stage_raw_files()

        for mode in self.splits:
            raw_path = osp.join(self.raw_dir, f"{mode}_crcg.npy")
            if not osp.exists(raw_path):
                print(f"Raw file {raw_path} not found, skipping processing for {mode}.")
                continue

            data_list = []
            for idx, record in enumerate(self._read_raw_graphs(raw_path)):
                data = self._record_to_data(record)

                # Data.validate only exists in newer PyG releases.
                if hasattr(data, 'validate') and callable(data.validate):
                    try:
                        data.validate(raise_on_error=True)
                    except ValueError as e:
                        print(f"Data validation failed for graph {idx} of {mode}_crcg.npy: {e}. Skipping.")
                        continue
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            # collate() cannot handle an empty list; skip instead of crashing.
            if not data_list:
                print(f"No valid graphs in {raw_path}, nothing saved for {mode}.")
                continue

            torch.save(self.collate(data_list), osp.join(self.processed_dir, f"Molecular_DIR_{mode}.pt"))

        print("Done!")

    def _stage_raw_files(self):
        """Copy ``{split}_crcg.npy`` from ``root`` into ``raw_dir`` when absent there."""
        import shutil

        for split in self.splits:
            src = osp.join(self.root, f"{split}_crcg.npy")
            tgt = osp.join(self.raw_dir, f"{split}_crcg.npy")
            if osp.exists(src) and not osp.exists(tgt):
                print(f"Copying {src} → {tgt}")
                os.makedirs(self.raw_dir, exist_ok=True)
                shutil.copy(src, tgt)

    @classmethod
    def _read_raw_graphs(cls, raw_path):
        """Load ``raw_path`` and return its graphs as a sequence of per-graph dicts.

        Supported layouts:

        * a 0-d object array, or a plain dict, holding parallel per-graph lists
          (columnar);
        * an array or list whose elements are per-graph dicts with keys
          ``'node_features'``, ``'edge_index'`` and ``'label'``.

        Raises:
            TypeError: If the loaded object matches none of these layouts.
        """
        # NOTE(review): allow_pickle=True can execute code embedded in the file;
        # only load raw files from a trusted source.
        loaded = np.load(raw_path, allow_pickle=True)

        if isinstance(loaded, np.ndarray):
            if loaded.ndim == 0 and isinstance(loaded.item(), dict):
                return cls._columns_to_records(loaded.item(), raw_path)
            try:
                return loaded.tolist()
            except Exception as err:
                raise TypeError(f"Unsupported np.ndarray format in {raw_path}") from err
        # With allow_pickle=True, np.load returns whatever object a non-.npy pickle holds.
        if isinstance(loaded, dict):
            return cls._columns_to_records(loaded, raw_path)
        if isinstance(loaded, list):
            return loaded
        raise TypeError(f"Unsupported data format in {raw_path}")

    @staticmethod
    def _columns_to_records(columns, raw_path):
        """Turn a dict of parallel per-graph lists into a list of per-graph dicts.

        Node features come from ``'node_features'``, falling back to ``'pos'``.
        Graphs are paired positionally; ``zip`` stops at the shortest list.

        Raises:
            KeyError: If no node-feature column, or no ``'edge_index'`` or
                ``'label'`` column, is present.
        """
        nf_list = columns.get('node_features', columns.get('pos'))
        if nf_list is None:
            raise KeyError(f"{raw_path} has neither 'node_features' nor 'pos'")
        return [
            {'node_features': nf, 'edge_index': ei, 'label': label}
            for nf, ei, label in zip(nf_list, columns['edge_index'], columns['label'])
        ]

    @staticmethod
    def _record_to_data(record):
        """Build a ``Data`` object from one per-graph dict.

        An ``edge_index`` given as ``[num_edges, 2]`` is transposed to
        ``[2, num_edges]``. An empty edge list becomes an empty ``[2, 0]``
        tensor. A scalar label is wrapped into a 1-element tensor.
        """
        # np.asarray first: torch.tensor on a list of ndarrays is very slow.
        x = torch.tensor(np.asarray(record['node_features']), dtype=torch.float)

        edge_index = torch.tensor(np.asarray(record['edge_index']), dtype=torch.long)
        if edge_index.numel() == 0:
            # Edge-less graphs (e.g. shape (0,)) would otherwise fail validation and be dropped.
            edge_index = edge_index.view(2, 0)
        elif edge_index.dim() == 2 and edge_index.size(0) != 2:
            edge_index = edge_index.t().contiguous()

        y = torch.tensor(np.asarray(record['label']), dtype=torch.long)
        if y.dim() == 0:
            y = y.unsqueeze(0)

        return Data(x=x, edge_index=edge_index, y=y)



























