import os
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import degree

class PaperDIRDataset(InMemoryDataset):
    """In-memory PyG dataset built from a ``{mode}_casual_1_3.npy`` dict file.

    The raw ``.npy`` file must hold a pickled ``dict`` with an ``'edge_index'``
    entry (one edge array per graph) and a label entry stored under
    ``'role_id'``, ``'label'`` or ``'ground_truth'`` (checked in that order).
    An optional ``'features'`` entry supplies node features; without it each
    node's feature is its out-degree. Every graph becomes a ``Data`` object
    with ``x``, ``edge_index`` and ``y``. The collated result is cached under
    ``processed/Paper_DIR_{mode}.pt``.

    Args:
        root: Dataset root directory. The raw file may live directly in
            ``root`` or in ``root/raw``.
        mode: Split name used to build the raw and processed file names
            (e.g. ``'train'``).
        transform: Optional transform applied on every access.
        pre_transform: Optional transform applied once to every graph
            before it is cached.
    """

    def __init__(self, root, mode='train', transform=None, pre_transform=None):
        # Both attributes must exist before the base constructor runs: it reads
        # the raw/processed file names (which depend on ``mode``) and may call
        # ``process()``.
        self.root = os.path.abspath(root)
        self.mode = mode
        super().__init__(self.root, transform, pre_transform)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f"{self.mode}_casual_1_3.npy"]

    @property
    def processed_file_names(self):
        return [f"Paper_DIR_{self.mode}.pt"]

    @property
    def raw_dir(self):
        # Use ``root/raw`` when it exists; otherwise read the raw file from ``root``.
        raw_path = os.path.join(self.root, 'raw')
        return raw_path if os.path.isdir(raw_path) else self.root

    def download(self):
        # Nothing to download: the raw file must already be on disk.
        pass

    def process(self):
        """Convert the raw ``.npy`` dict into collated graphs and cache them.

        Raises:
            FileNotFoundError: if the raw file cannot be found.
            KeyError: if the raw dict lacks the edge or label entry.
            ValueError: if the raw file yields no graphs.
        """
        import itertools

        raw_path = self._ensure_raw_file()
        edge_indices, labels_list, features_list = self._load_raw(raw_path)

        # ``zip`` stops at the shortest sequence, so edges, labels and (optional)
        # features are aligned to a common number of graphs.
        feats_iter = features_list if features_list is not None else itertools.repeat(None)
        data_list = [
            self._build_graph(edges, feats, lab)
            for edges, lab, feats in zip(edge_indices, labels_list, feats_iter)
        ]
        if not data_list:
            raise ValueError(f"No graphs found in {raw_path}")

        # PyG does not apply ``pre_transform`` itself; ``process`` must do it.
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple written by ``process``."""
        import inspect

        # torch >= 2.6 defaults to ``weights_only=True``, which refuses to
        # unpickle PyG ``Data`` objects. The cache file is written by this class
        # itself, so full unpickling is acceptable. Older torch versions do not
        # have the argument at all.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    def _ensure_raw_file(self):
        """Return the raw file path, copying it from ``root`` into ``raw/`` if needed.

        Raises:
            FileNotFoundError: if the raw file is in neither location.
        """
        import shutil

        fname = self.raw_file_names[0]
        src = os.path.join(self.root, fname)
        dst = os.path.join(self.raw_dir, fname)
        if os.path.exists(src) and not os.path.exists(dst):
            os.makedirs(self.raw_dir, exist_ok=True)
            shutil.copy(src, dst)
        if not os.path.exists(dst):
            raise FileNotFoundError(f"Raw file {dst} not found")
        return dst

    @staticmethod
    def _load_raw(raw_path):
        """Read the raw dict and return ``(edge_indices, labels_list, features_list)``.

        A value that is not a ``list``/``ndarray`` is wrapped in a one-element
        list, so callers can always iterate per graph. ``features_list`` is
        ``None`` when the file has no ``'features'`` entry.

        Raises:
            KeyError: if the edge or label entry is missing.
        """
        # NOTE: ``allow_pickle=True`` can execute arbitrary code embedded in the
        # file; only load raw files from a trusted source.
        np_data = np.load(raw_path, allow_pickle=True).item()
        edge_indices = np_data.get('edge_index')
        features_list = np_data.get('features', None)
        labels_list = np_data.get('role_id', np_data.get('label', np_data.get('ground_truth')))
        if edge_indices is None or labels_list is None:
            raise KeyError("NPY file must contain 'edge_index' and 'role_id'/'label'/'ground_truth' keys")

        def as_seq(value):
            # NOTE(review): a single graph stored as a bare ndarray is NOT wrapped
            # and will be iterated row-wise; confirm the raw format always stores
            # one entry per graph.
            return value if isinstance(value, (list, np.ndarray)) else [value]

        edge_indices = as_seq(edge_indices)
        labels_list = as_seq(labels_list)
        if features_list is not None:
            features_list = as_seq(features_list)
        return edge_indices, labels_list, features_list

    @staticmethod
    def _to_edge_index(edges):
        """Convert one graph's edges to a contiguous ``[2, E]`` long tensor."""
        if isinstance(edges, torch.Tensor):
            ei = edges.long()
        else:
            ei = torch.tensor(edges, dtype=torch.long)
        if ei.ndim == 1:
            # Flat layout: first half are sources, second half are targets.
            ei = ei.view(2, -1)
        elif ei.ndim == 2 and ei.size(0) != 2 and ei.size(1) == 2:
            # ``[E, 2]`` edge list -> ``[2, E]``.
            ei = ei.t()
        return ei.contiguous()

    @staticmethod
    def _fit_rows(t, n):
        """Zero-pad or truncate ``t`` along dim 0 so it has exactly ``n`` rows."""
        if t.size(0) < n:
            pad = t.new_zeros((n - t.size(0),) + tuple(t.shape[1:]))
            return torch.cat([t, pad], dim=0)
        return t[:n]

    @classmethod
    def _build_graph(cls, edges, feats, lab):
        """Build one ``Data(x, edge_index, y)`` from raw edges, features and labels."""
        ei = cls._to_edge_index(edges)

        # Base node count comes from the features, else from a sized label;
        # a scalar label carries no node count.
        if feats is not None:
            num_nodes = len(feats)
        elif np.ndim(lab) > 0:
            num_nodes = len(lab)
        else:
            num_nodes = 0
        # Every edge endpoint must be a valid row of ``x``/``y``. Otherwise
        # ``degree`` fails and the resulting ``Data`` is inconsistent.
        if ei.numel() > 0:
            num_nodes = max(num_nodes, int(ei.max()) + 1)
        num_nodes = int(num_nodes)

        if feats is not None:
            if isinstance(feats, torch.Tensor):
                x = feats.float()
            else:
                x = torch.tensor(feats, dtype=torch.float)
            if x.dim() == 1:
                # One scalar feature per node -> ``[N, 1]``, matching the
                # degree-feature branch below.
                x = x.view(-1, 1)
            x = cls._fit_rows(x, num_nodes)
        else:
            deg = degree(ei[0], num_nodes=num_nodes, dtype=torch.long)
            x = deg.view(-1, 1).float()

        # Labels are flattened and zero-padded/truncated to one per node.
        # NOTE(review): a scalar (graph-level) label becomes ``[label, 0, 0, ...]``;
        # confirm that is intended for 'label'/'ground_truth' files.
        if isinstance(lab, torch.Tensor):
            y = lab.long()
        else:
            y = torch.tensor(lab, dtype=torch.long)
        y = cls._fit_rows(y.reshape(-1), num_nodes)

        return Data(x=x, edge_index=ei, y=y)
