

# %%

from data_modules.twitter_wall import TwitterWallDataModule

# Build the TwitterWall data module used by the inspection cells below.
# Every keyword value is forwarded unchanged to TwitterWallDataModule; see
# that class for what T, L, Q and sigma control.
# NOTE(review): `features` looks like an 11-flag on/off mask selecting input
# columns (here only the first three enabled) — confirm against the module.
data_module = TwitterWallDataModule(
    data_dir="../Data",
    scaler="power",
    T=30,
    L=10,
    Q=0,
    sigma=8,
    features="11100000000",
    batch_size=1,
)

# %%

# Sanity-check the in-degree data exposed by the data module: fetch it once
# with bincount=False and once with bincount=True, printing each result's shape.
# NOTE(review): bincount=True presumably returns a histogram of degree values
# rather than one entry per node — confirm against get_deg.
deg1 = data_module.get_deg(indegree=True, bincount=False)
print(deg1.shape)
deg2 = data_module.get_deg(indegree=True, bincount=True)
print(deg2.shape)
# %%
# Count how many batches one full pass over the training dataloader yields.
# The loader is iterated (rather than asking for its length) so the count
# reflects what iteration actually produces; an empty loader prints 0.
i = sum(1 for _ in data_module.train_dataloader())
print(i)
# %%

# For every training batch, print the column-wise max and min of the node
# features (data.x) and edge features (data.edge_attr), to eyeball the value
# ranges after scaling.
# NOTE(review): assumes x / edge_attr are torch tensors, where .max(0) and
# .min(0) return a (values, indices) pair — the [0] keeps only the values.
for data in data_module.train_dataloader():
    print('data.x')
    print(data.x.max(0)[0])
    print(data.x.min(0)[0])
    print('\ndata.edge_attr')
    print(data.edge_attr.max(0)[0])
    print(data.edge_attr.min(0)[0])
    print('--------------')



# %%
from data_modules.twitter_se import TwitterSEDataModule

# Build the TwitterSE data module. This rebinds `data_module`, replacing the
# TwitterWall module created in the first cell; re-run that cell before
# re-running the TwitterWall inspection cells.
# All keyword values are forwarded unchanged to TwitterSEDataModule.
# NOTE(review): the features_* strings look like per-column on/off masks
# (11 / 11 / 4 flags, all enabled) — confirm against the module.
data_module = TwitterSEDataModule(
    data_dir="../Data",
    batch_size=512,
    balance="resample",
    scaler_name="power",
    features_s="11111111111",
    features_e="11111111111",
    features_l="1111",
    split_se=True,
    omf=1,
    ouw=1,
    split_random=True,
)


# %%
# For each training batch, print 100 * mean(y).
# NOTE(review): assuming y holds 0/1 labels, this is the percentage of
# positives per batch — a quick check of the balance='resample' setting.
for data in data_module.train_dataloader():
    x, y, s = data  # batches unpack into three items; only y is inspected
    positive_pct = y.mean() * 100
    print(positive_pct)
