import os
import torch
from torch.utils.data import Dataset
import numpy as np
from natsort import natsorted

class LaplacianDataset(Dataset):
    """Map-style dataset backed by the ``.pt`` files of a single directory.

    One file is one sample; each file is read lazily on access and must
    provide the keys ``'x'`` and ``'y'``.
    """

    def __init__(self, root):
        """
        Remembers the data directory and indexes the .pt files inside it.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        # Keep only the .pt entries; natural ordering fixes the index -> file mapping.
        pt_names = [name for name in os.listdir(root) if name.endswith('.pt')]
        self.files = natsorted(pt_names)

    def __len__(self):
        """
        Returns the number of indexed .pt files (one sample per file).
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Loads the sample stored in the `index`-th file of the sorted listing.

        Parameters:
        index (int): Position of the file in `self.files`.

        Returns:
        tuple: `(x, y)`, the objects stored under the 'x' and 'y' keys.
        """
        sample_path = os.path.join(self.directory, self.files[index])
        sample = torch.load(sample_path)
        return sample['x'], sample['y']
    
class LaplacianDatasetEdgeIndex(Dataset):
    """Map-style dataset over ``.pt`` files that also carry an ``edge_index``.

    Every file must provide the keys ``'x'``, ``'y'`` and ``'edge_index'``.
    All targets are additionally read once at construction time and exposed
    concatenated as ``self.y``.
    """

    def __init__(self, root):
        """
        Indexes the .pt files in `root` and pre-loads all targets.

        Note that every file is deserialized once here, so construction cost
        grows with the number of files. `torch.cat` does not accept an empty
        list, so the directory must contain at least one .pt file.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        # List of all .pt files in the directory, in natural sort order
        self.files = natsorted([f for f in os.listdir(root) if f.endswith('.pt')])
        # Concatenated 'y' entries of all files, in the same order as self.files
        self.y = torch.cat([self._load(f)['y'] for f in self.files])

    def _load(self, file_name):
        """Deserializes one sample file located in the dataset directory."""
        file_path = os.path.join(self.directory, file_name)
        # Single loader for __init__ and __getitem__ so both use the same
        # settings: a file that loads during construction also loads on access.
        # SECURITY: weights_only=False unpickles arbitrary objects; only point
        # this dataset at trusted files.
        return torch.load(file_path, weights_only=False)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Loads and returns the sample stored in the `index`-th file.

        Parameters:
        index (int): Position of the file in `self.files`.

        Returns:
        tuple: `(x[:37], x[37:47], y, edge_index)` taken from the loaded file.
        """
        data = self._load(self.files[index])
        x = data['x']
        y = data['y']
        # NOTE(review): the split points 37 and 47 are hard-coded; presumably
        # they separate two feature groups along the first axis of x - confirm
        # against the code that writes these files.
        return x[:37], x[37:47], y, data['edge_index']
    
class LaplacianDatasetDegree(Dataset):
    """Map-style dataset over ``.pt`` files that also carry degree information.

    Every file must provide the keys ``'x'``, ``'y'``, ``'degree'`` and
    ``'max_degree'``. All targets are additionally read once at construction
    time and exposed concatenated as ``self.y``.
    """

    def __init__(self, root):
        """
        Indexes the .pt files in `root` and pre-loads all targets.

        Note that every file is deserialized once here, so construction cost
        grows with the number of files. `torch.cat` does not accept an empty
        list, so the directory must contain at least one .pt file.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        # List of all .pt files in the directory, in natural sort order
        self.files = natsorted([f for f in os.listdir(root) if f.endswith('.pt')])
        # Concatenated 'y' entries of all files, in the same order as self.files
        self.y = torch.cat([self._load(f)['y'] for f in self.files])

    def _load(self, file_name):
        """Deserializes one sample file located in the dataset directory."""
        file_path = os.path.join(self.directory, file_name)
        # Single loader for __init__ and __getitem__ so both use the same
        # settings: a file that loads during construction also loads on access.
        # SECURITY: weights_only=False unpickles arbitrary objects; only point
        # this dataset at trusted files.
        return torch.load(file_path, weights_only=False)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Loads and returns the sample stored in the `index`-th file.

        Parameters:
        index (int): Position of the file in `self.files`.

        Returns:
        tuple: `(x, y, degree, max_degree)`, the objects stored under the
        keys of the same names in the loaded file.
        """
        data = self._load(self.files[index])
        return data['x'], data['y'], data['degree'], data['max_degree']
    
    
class LaplacianDatasetMolPCBA(Dataset):
    """Map-style dataset over the ``.npz`` archives of a single directory.

    One archive is one sample; each archive must provide the arrays ``'x'``
    and ``'y'``, which are converted to tensors on access.
    """

    def __init__(self, root):
        """
        Initializes the dataset with the directory containing .npz files.

        Parameters:
        root (str): Path to the directory containing .npz files.
        """
        self.directory = root
        # List of all .npz files in the directory, in natural sort order
        self.files = natsorted([f for f in os.listdir(root) if f.endswith('.npz')])

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Loads and returns the sample stored in the `index`-th .npz file.

        Parameters:
        index (int): Position of the file in `self.files`.

        Returns:
        tuple or None: `(x, y)` as tensors built from the 'x' and 'y' arrays,
        or None if the file could not be read or converted (diagnostics are
        printed in that case).

        Raises:
        IndexError: If `index` is out of range.
        """
        # Resolve the file name outside the try block: an out-of-range index
        # must raise IndexError, which is also what terminates plain
        # `for sample in dataset` iteration over a __getitem__-only object.
        file_path = os.path.join(self.directory, self.files[index])
        try:
            # The context manager closes the archive's file handle; the
            # tensors are built before the archive is closed.
            with np.load(file_path) as data:
                x = torch.tensor(data['x'])
                y = torch.tensor(data['y'])
        except Exception as e:
            # Best-effort: report the unreadable sample and signal it with
            # None. KeyboardInterrupt/SystemExit are deliberately not caught.
            print('========hhhhheeeee======================')
            print(index)
            print(len(self))
            print(e)
            return None
        return x, y
            
            
class LaplacianDatasetMoleSol(Dataset):
    """Map-style dataset over the ``.pt`` files of a single directory.

    One file is one sample; each file must provide the keys ``'x'`` and
    ``'y'``. Unreadable files yield None instead of raising.
    """

    def __init__(self, root):
        """
        Initializes the dataset with the directory containing .pt files.

        Parameters:
        root (str): Path to the directory containing .pt files.
        """
        self.directory = root
        # List of all .pt files in the directory, in natural sort order
        self.files = natsorted([f for f in os.listdir(root) if f.endswith('.pt')])

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.files)

    def __getitem__(self, index):
        """
        Loads and returns the sample stored in the `index`-th .pt file.

        Parameters:
        index (int): Position of the file in `self.files`.

        Returns:
        tuple or None: `(x, y)`, the objects stored under the 'x' and 'y'
        keys, or None if the file could not be loaded (diagnostics are
        printed in that case).

        Raises:
        IndexError: If `index` is out of range.
        """
        # Resolve the file name outside the try block: an out-of-range index
        # must raise IndexError, which is also what terminates plain
        # `for sample in dataset` iteration over a __getitem__-only object.
        file_path = os.path.join(self.directory, self.files[index])
        try:
            data = torch.load(file_path)
            x = data['x']
            y = data['y']
        except Exception as e:
            # Best-effort: report the unreadable sample and signal it with
            # None. KeyboardInterrupt/SystemExit are deliberately not caught.
            print('========hhhhheeeee======================')
            print(index)
            print(len(self))
            print(e)
            return None
        return x, y