import os
import os.path as osp
import shutil
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
import torch_geometric.utils as pyg_utils

class PaperDataset(InMemoryDataset):
    """In-memory graph dataset built from per-split ``*_casual_1_3.npy`` files.

    Each raw file is expected to hold a pickled ``dict`` with the keys
    ``'edge_index'``, ``'features'`` and ``'role_id'``, each an iterable with
    one entry per graph.  Every graph becomes a ``Data`` object with node
    features ``x``, connectivity ``edge_index``, all-ones ``edge_attr`` and
    node labels ``y``.

    Args:
        root: Dataset root directory; raw files are read from ``self.raw_dir``
            and processed files are written to ``self.processed_paths``.
        mode: Which split to load; one of ``splits``.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset`` and applied to every
            graph in ``process`` before it is saved.
        pre_filter: Forwarded to ``InMemoryDataset``; graphs for which it
            returns a falsy value are dropped in ``process``.
        force_reload: If true, delete this split's processed file first so it
            is rebuilt from the raw data.  Processed files of the other splits
            are left untouched.
    """

    splits = ['train', 'val', 'test']

    def __init__(self, root, mode='train', transform=None, pre_transform=None,
                 pre_filter=None, force_reload=False):
        assert mode in self.splits, f"Mode must be one of {self.splits}"
        # Must be set before super().__init__(), which may call process().
        self.mode = mode

        # Only discard the cache on request, and only for the current split:
        # wiping the whole processed directory on every construction would
        # defeat caching and destroy the other splits' processed files.
        if force_reload:
            cached = osp.join(root, 'processed', f'Paper_{mode}.pt')
            if osp.exists(cached):
                os.remove(cached)

        super().__init__(root, transform, pre_transform, pre_filter)

        idx = self.processed_file_names.index(f'Paper_{mode}.pt')
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True -- confirm this call still loads the pickled
        # (data, slices) tuple on the PyTorch version in use.
        self.data, self.slices = torch.load(self.processed_paths[idx])

    @property
    def raw_file_names(self):
        """Return the raw ``.npy`` file names, one per split, in ``splits`` order."""
        return [f'{m}_casual_1_3.npy' for m in self.splits]

    @property
    def processed_file_names(self):
        """Return the processed ``.pt`` file names, one per split, in ``splits`` order."""
        return [f'Paper_{m}.pt' for m in self.splits]

    def download(self):
        """Do nothing; the raw files are expected to be present already."""
        pass

    @staticmethod
    def _build_graph(edges, feats, roles):
        """Convert one graph's raw arrays into a ``Data`` object."""
        ei = torch.tensor(edges, dtype=torch.long)
        if ei.numel() == 0:
            # A graph without edges arrives as a 1-D (or otherwise degenerate)
            # empty array; normalise it to the canonical empty shape.
            ei = torch.empty((2, 0), dtype=torch.long)
        elif ei.ndim == 2 and ei.size(0) != 2:
            # Accept edge lists whose first dimension is not 2 by transposing.
            ei = ei.t()
        ei = ei.contiguous()

        x = torch.tensor(feats, dtype=torch.float)
        num_nodes = int(x.size(0))

        # Force exactly one label per node: zero-pad short label arrays and
        # truncate long ones.
        y = torch.tensor(roles, dtype=torch.long)
        if y.numel() < num_nodes:
            pad = torch.zeros(num_nodes - y.numel(), dtype=torch.long)
            y = torch.cat([y, pad], dim=0)
        elif y.numel() > num_nodes:
            y = y[:num_nodes]

        # Every edge gets a constant weight of 1.
        edge_attr = torch.ones(ei.size(1), dtype=torch.float)

        return Data(x=x, edge_index=ei, edge_attr=edge_attr, y=y)

    def process(self):
        """Build the current split's graphs from its raw file and save them.

        Only the file belonging to ``self.mode`` is read and written.

        Raises:
            KeyError: If the raw dict lacks ``'edge_index'``, ``'features'``
                or ``'role_id'``.
        """
        idx = self.raw_file_names.index(f'{self.mode}_casual_1_3.npy')
        raw_path = osp.join(self.raw_dir, self.raw_file_names[idx])
        # SECURITY: allow_pickle=True executes arbitrary pickled code on load;
        # only use raw files from a trusted source.
        # .item() unwraps the 0-d object array that holds the saved dict.
        np_data = np.load(raw_path, allow_pickle=True).item()

        edge_idxs = np_data.get('edge_index')
        features_list = np_data.get('features')
        roles_list = np_data.get('role_id')

        if edge_idxs is None or features_list is None or roles_list is None:
            raise KeyError("NPY file must contain 'edge_index', 'features' and 'role_id' keys")

        data_list = [
            self._build_graph(edges, feats, roles)
            for edges, feats, roles in zip(edge_idxs, features_list, roles_list)
        ]

        # Honour the hooks accepted by the constructor (filter first, then
        # transform) so they are baked into the processed file.
        if self.pre_filter is not None:
            data_list = [graph for graph in data_list if self.pre_filter(graph)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(graph) for graph in data_list]

        data, slices = self.collate(data_list)
        # processed_file_names follows the same split order as raw_file_names,
        # so the raw index also selects the matching processed path.
        torch.save((data, slices), self.processed_paths[idx])
