import os, gzip, re, bisect
import dgl
import dgl.function as fn
from dgl.transforms import add_self_loop
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class VitgnnDataset(Dataset):
    """In-memory dataset of per-design feature/label tensors loaded from ``.npy`` files.

    Sample names are read from a text file (one name per line). For every
    name the feature array is loaded from ``<feature_root>/features/<name>.npy``.
    The label array depends on ``args.label``:

    * ``"thermal"``: loaded from ``<data_root>/hotmap/<name>.npy`` and given
      a new leading dimension via ``unsqueeze(0)``.
    * anything else: loaded from ``<feature_root>/labels/<name>.npy``. After
      stacking, ``"congestion"``, ``"drc"`` and ``"ir_drop"`` keep only
      index 0, 1 and 2 (respectively) of dimension 1; any other label value
      keeps the stacked labels unchanged.

    All arrays are loaded eagerly in the constructor and stacked, so every
    sample must share the same feature shape and the same label shape.
    """

    # Index along dimension 1 of the stacked label tensor for each task.
    _LABEL_CHANNELS = {"congestion": 0, "drc": 1, "ir_drop": 2}

    def __init__(self, data_root, feature_root, namelist, args):
        """Read the name list and load every sample into memory.

        Args:
            data_root: Directory containing the name-list file and, for the
                thermal task, the ``hotmap`` sub-directory.
            feature_root: Directory containing the ``features`` (and
                ``labels``) sub-directories.
            namelist: File name of the sample list, relative to ``data_root``.
            args: Object exposing a ``label`` attribute that selects the task.
        """
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.label = args.label
        self.data_root = data_root
        self.feature_root = feature_root
        self.name_list = []
        # The context manager closes the file even if reading raises.
        with open(os.path.join(data_root, namelist), mode='r') as list_file:
            for line in list_file:
                name = line.rstrip("\r\n")
                # A blank line would otherwise become the name "" and make
                # loading fail on "<root>/features/.npy".
                if not name.strip():
                    continue
                self.name_list.append(name)
        self.generate_Dataset()

    @staticmethod
    def _load_tensor(*path_parts):
        """Load the ``.npy`` file at ``os.path.join(*path_parts)`` as a tensor."""
        return torch.tensor(np.load(os.path.join(*path_parts)))

    def generate_Dataset(self):
        """Populate ``self.features`` and ``self.labels`` with stacked tensors."""
        print(self.label)
        is_thermal = self.label == "thermal"
        features = []
        labels = []
        for name in self.name_list:
            file_name = name + ".npy"
            features.append(
                self._load_tensor(self.feature_root, "features", file_name)
            )
            if is_thermal:
                # Thermal maps live under data_root and get an extra
                # leading dimension per sample.
                labels.append(
                    self._load_tensor(self.data_root, "hotmap", file_name).unsqueeze(0)
                )
            else:
                labels.append(
                    self._load_tensor(self.feature_root, "labels", file_name)
                )
        self.features = torch.stack(features)
        self.labels = torch.stack(labels)

        # "thermal" and unrecognised labels have no entry here, so their
        # stacked labels are kept as-is.
        channel = self._LABEL_CHANNELS.get(self.label)
        if channel is not None:
            # Indexing with a list keeps dimension 1 (size 1) in the result.
            self.labels = self.labels[:, [channel], :, :]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair at position ``idx``."""
        return self.features[idx], self.labels[idx]
