import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch_geometric.datasets import TUDataset
from model import GraphProp
from node_attr import get_node_attribute_emb

class MyDataset2(Dataset):
    """Dataset pairing model embeddings with prompt-derived graph features.

    Each sample is ``(feature_vector, label)``. The feature vectors are the
    concatenation along dimension 1 of:

    * the output of ``model.emb`` for every row of
      ``./data/<name>/after/data.npy`` (``model`` is loaded from ``model.pt``);
    * the output of ``get_node_attribute_emb`` for a text prompt built from
      the dataset README and the node feature matrix of every graph that has
      at most ``MAX_NODES`` nodes.

    Labels are read from ``./data/<name>/after/labels.npy``.
    """

    # Graphs with more nodes than this are skipped when building the prompt
    # features.
    # NOTE(review): presumably data.npy / labels.npy were generated with the
    # same filter so that rows line up with the kept graphs -- confirm.
    MAX_NODES = 128

    def __init__(self, name):
        """Load the artifacts for the TU dataset ``name`` and build the features.

        Args:
            name: TUDataset name; also the sub-directory of ``./data`` that
                holds ``raw/README.txt``, ``after/data.npy`` and
                ``after/labels.npy``.
        """
        device = torch.device('cpu')

        # TUDataset is created first; the README below lives in its raw folder.
        dataset = TUDataset(root='./data', name=name)
        with open(f'./data/{name}/raw/README.txt', 'r') as f:
            description = f.read()

        self.labels = np.load(f'./data/{name}/after/labels.npy')

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code -- only load a trusted 'model.pt'.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which rejects fully pickled models -- confirm against the installed
        # version.
        model = torch.load('model.pt', map_location=device)
        rows = torch.tensor(
            np.load(f'./data/{name}/after/data.npy'), dtype=torch.float32
        )
        emb = self._embed_rows(model, rows)
        print(emb)

        features = [
            get_node_attribute_emb(
                f"The dataset's description is: {description}. The node feature matrix is: {graph.x.tolist()}。"
            )
            for graph in dataset
            if graph.num_nodes <= self.MAX_NODES
        ]
        features = torch.tensor(np.array(features), dtype=torch.float32)
        print(features)

        # torch.cat requires both blocks to have the same number of rows.
        self.data = torch.cat((emb, features), 1)

    @staticmethod
    def _embed_rows(model, rows):
        """Apply ``model.emb`` to each row and stack the results as float32.

        Runs under ``torch.no_grad()`` so no autograd graph is built and the
        outputs can be combined without tracking gradients.
        NOTE(review): the model is used in whatever train/eval mode it was
        saved in; call ``model.eval()`` beforehand if it contains dropout or
        batch-norm layers -- confirm.
        """
        with torch.no_grad():
            outputs = [
                torch.as_tensor(model.emb(row), dtype=torch.float32)
                for row in rows
            ]
        return torch.stack(outputs)

    def __len__(self):
        """Return the number of samples (rows of the feature tensor)."""
        return len(self.data)

    def input_dim(self):
        """Return the size of dimension 1 of the feature tensor."""
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]


def load_data2():
    """Build a shuffled ``DataLoader`` over ``MyDataset2('PROTEINS')``.

    Returns:
        A ``DataLoader`` yielding batches of up to 128 ``(features, label)``
        pairs in shuffled order.
    """
    name = 'PROTEINS'
    dataset = MyDataset2(name)
    # Use the imported DataLoader class; the original called an unbound local
    # named ``dataloader`` and passed an undefined ``dataset``.
    return DataLoader(dataset, batch_size=128, shuffle=True)
