import os
import os.path as osp
import warnings

import torch
from torch_geometric.data import Dataset

class NetlistDataset(Dataset):
    """Netlist samples paired with labels, built on ``torch_geometric.data.Dataset``.

    Samples and labels are kept in two parallel sequences (``self.data`` and
    ``self.label``). They are loaded from, and saved to, a single file written
    with ``torch.save`` as a ``(data, label)`` tuple. Indexing the dataset
    returns a ``(sample, label)`` pair.
    """

    def __init__(self, root, processed_file=None, transform=None, pre_transform=None, pre_filter=None):
        """Create the dataset and load previously saved samples if present.

        :param root: The root directory for the dataset.
        :param processed_file: Optional file name, relative to ``root``, to load
            the ``(data, label)`` tuple from. If omitted, ``self.processed_paths[0]``
            (the base class's processed directory joined with ``data_netlist.pt``)
            is used.
        :param transform: Optional callable applied to each sample (not to its
            label) every time it is accessed. It is stored as ``self.sample_transform``.
        :param pre_transform: Forwarded unchanged to the base ``Dataset``.
        :param pre_filter: Forwarded unchanged to the base ``Dataset``.
        """
        # The user transform is deliberately NOT handed to the base class:
        # PyG's ``Dataset.__getitem__`` applies ``self.transform`` to whatever
        # ``get`` returns, which here is a (sample, label) tuple. Also applying
        # it in ``get`` would run it twice, the second time on the tuple.
        super().__init__(root, None, pre_transform, pre_filter)
        self.sample_transform = transform

        if processed_file:
            file_path = osp.join(root, processed_file)
        else:
            # processed_paths[0] already includes the root directory, so it must
            # not be prefixed with ``root`` again. This is also the path
            # ``process_new_data`` writes to by default.
            file_path = self.processed_paths[0]

        if osp.exists(file_path):
            print(f"Found dataset at {file_path}")
            # NOTE(review): torch.load unpickles arbitrary objects, so only load
            # trusted files. Newer PyTorch releases default to weights_only=True,
            # which may reject pickled graph objects; confirm against the
            # installed torch version.
            self.data, self.label = torch.load(file_path)
        else:
            warnings.warn(f"File {file_path} not found, no data or labels loaded.")
            self.data = []
            self.label = []

    @property
    def processed_file_names(self):
        """Name of the default processed file, resolved by ``processed_paths``."""
        return ['data_netlist.pt']

    def len(self):
        """Return the number of stored samples."""
        return len(self.data)

    def get(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``.

        If a transform was given to the constructor, it is applied to the
        sample only; the label is returned unchanged.

        :param index: The index of the pair to return.
        """
        sample = self.data[index]
        label = self.label[index]

        if self.sample_transform is not None:
            sample = self.sample_transform(sample)

        return sample, label

    def add_item(self, data, label):
        """Append one sample and its label to the in-memory dataset.

        :param data: The sample to add.
        :param label: The label for the sample.
        """
        self.data.append(data)
        self.label.append(label)

    def process_new_data(self, file=None):
        """Save the current samples and labels with ``torch.save``.

        The ``(data, label)`` tuple layout matches what ``__init__`` loads.

        :param file: Destination passed straight to ``torch.save`` (it is NOT
            joined with ``root``). If omitted, a warning is emitted and the
            default processed file ``self.processed_paths[0]`` is overwritten.
        """
        if file:
            torch.save((self.data, self.label), file)
            return

        warnings.warn("Will now override existing dataset because no new file was provided.")
        target = self.processed_paths[0]
        # This class defines no ``process()`` step, so the processed directory
        # is not guaranteed to exist yet. Create it before writing.
        os.makedirs(osp.dirname(target), exist_ok=True)
        torch.save((self.data, self.label), target)

