import os
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

class LaplacianDataset(Dataset):
    """Dataset over every ``.pt`` file in one directory, yielding ``(x, y)`` pairs.

    Each file is read with ``torch.load`` and indexed with the keys ``'x'``
    and ``'y'``. The loaded object is presumably a dict of tensors; confirm
    this against whatever produced the files. Files are served in the order
    given by ``sorted`` on their names (plain string order, so ``'10.pt'``
    comes before ``'2.pt'``).
    """

    def __init__(self, root):
        """
        Record ``root`` and collect the names of its ``.pt`` files.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        entries = os.listdir(root)
        pt_names = [name for name in entries if name.endswith('.pt')]
        pt_names.sort()
        self.files = pt_names

    def __len__(self):
        """
        Return how many ``.pt`` files were found when the dataset was built.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Load the ``index``-th file (in sorted filename order) and return its parts.

        Parameters:
        index (int): Position of the file within ``self.files``.

        Returns:
        tuple: ``(loaded['x'], loaded['y'])`` from the loaded file.
        """
        loaded = torch.load(os.path.join(self.directory, self.files[index]))
        return loaded['x'], loaded['y']
    
class CustomGraphData(Dataset):
    """Graph dataset built from a directory of ``.pt`` files.

    Each file is read with ``torch.load`` and indexed with the keys ``'x'``,
    ``'y'`` and ``'edge_indexx'``. The loaded object is presumably a dict of
    tensors; confirm this against whatever produced the files. Samples are
    returned as ``Data`` graph objects moved to the CPU.
    """

    def __init__(self, root):
        """
        Record ``root`` and collect the names of its ``.pt`` files.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        # All .pt files in the directory, in sorted (plain string) name order.
        self.files = sorted(f for f in os.listdir(root) if f.endswith('.pt'))
        # Always 1 because __getitem__ reshapes x to (-1, 1). No file has to
        # be opened to know this, so an empty directory no longer fails here.
        self.num_node_features = 1

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Load the ``index``-th file (in sorted filename order) as a graph.

        Parameters:
        index (int): Position of the file within ``self.files``.

        Returns:
        Data: a graph on the CPU with
          - ``x``: the stored ``'x'`` cast to float and reshaped to
            ``(num_nodes, 1)``;
          - ``y``: the stored ``'y'``, unchanged;
          - ``edge_index``: the stored ``'edge_indexx'`` cast to long, keeping
            only columns whose entries are all valid node ids in
            ``[0, num_nodes)``.
        """
        file_path = os.path.join(self.directory, self.files[index])
        data = torch.load(file_path)
        x = data['x'].float().reshape(-1, 1)
        y = data['y']
        # NOTE(review): the key is spelled 'edge_indexx' in the stored files.
        # Confirm with the data producer before renaming it.
        edge_index = data['edge_indexx'].long()
        n = x.shape[0]

        # Node ids index rows of x, so only 0..n-1 are valid. An id equal to n
        # (or a negative id) would be out of bounds for x. Keep a column only
        # if every entry in it is in range.
        in_range = (edge_index >= 0) & (edge_index < n)
        mask = in_range.all(dim=0)
        edge_index_filtered = edge_index[:, mask]

        return Data(x=x, edge_index=edge_index_filtered, y=y).to(device='cpu')