import torch
from torch_geometric.data import Data

def load_data(fileName):
    """Parse a whitespace-separated integer graph file into PyG ``Data`` objects.

    Expected layout, one record per line::

        <num_graphs>
        <num_nodes>                         <- start of a graph
        <y> <deg> <nbr_1> ... <nbr_deg>     <- repeated num_nodes times
        <num_nodes>                         <- next graph ...

    Graphs are read until the end of the file; the ``<num_graphs>`` header is
    skipped, not enforced. Blank lines (e.g. a trailing newline) are ignored.

    Args:
        fileName: Path to the text file.

    Returns:
        A list with one ``Data`` per graph, holding
          * ``x``: float tensor ``(num_nodes, 1)`` -- each node's ``deg`` field.
          * ``edge_index``: long tensor ``(2, num_edges)``; row 0 is the node's
            position within its graph, row 1 the neighbour id exactly as
            written in the file (presumably also a 0-based in-graph index --
            confirm against whatever produces these files).
          * ``y``: long tensor ``(num_nodes,)`` of per-node labels.
          * ``batch``: ``zeros(num_nodes)`` -- every node belongs to this graph.
          * ``ptr``: ``[0, num_nodes]``.

    Raises:
        ValueError: if a graph declares more nodes than remain in the file.
        IndexError: if a node line lists fewer neighbours than its ``deg``.

    Side effect: prints the mean node label over all graphs followed by
    ``'balance'`` (``nan`` when the file contains no nodes).
    """
    with open(fileName, "r") as my_file:
        # Drop blank lines: a trailing newline would otherwise produce an
        # empty record that gets mistaken for a graph header.
        records = [list(map(int, line.split())) for line in my_file if line.strip()]

    ans = []
    label_sum = 0
    node_total = 0
    i = 1  # records[0] is the <num_graphs> header
    while i < len(records):
        num_nodes = records[i][0]
        i += 1
        node_rows = records[i:i + num_nodes]
        if len(node_rows) < num_nodes:
            raise ValueError(
                f"graph header declares {num_nodes} nodes but only "
                f"{len(node_rows)} node lines follow"
            )
        i += num_nodes

        y, x, src, dst = [], [], [], []
        for local_id, row in enumerate(node_rows):
            y.append(row[0])
            deg = row[1]
            x.append(deg)
            for k in range(deg):
                src.append(local_id)
                dst.append(row[k + 2])

        label_sum += sum(y)
        node_total += len(y)
        # Explicit dtypes: torch.tensor([[], []]) would infer float32 for an
        # edgeless graph, which is not a valid edge_index.
        # batch/ptr describe this graph alone so their sizes agree with x.
        ans.append(Data(
            x=torch.tensor(x, dtype=torch.float).unsqueeze(1),
            batch=torch.zeros(num_nodes, dtype=torch.long),
            ptr=torch.tensor([0, num_nodes], dtype=torch.long),
            edge_index=torch.tensor([src, dst], dtype=torch.long),
            y=torch.tensor(y, dtype=torch.long),
        ))
    print(label_sum / node_total if node_total else float("nan"), 'balance')
    return ans