# %%
from datasets.triangle import TriangleSCM
from datasets.transforms import ToTensor
import torch
# Build the linear-equation triangle SCM dataset and inspect its first sample.
dataset_lin = TriangleSCM(equations_type='linear', transform=ToTensor())
# NOTE(review): 1000 is presumably the number of samples to generate — confirm
# against TriangleSCM.prepare_data.
dataset_lin.prepare_data(1000, add_self_loop=True)

# Optional toggle: uncomment to set an intervention on node 'x2' (value 0)
# before fetching the sample.
# dataset_lin.set_intervention(x_I={'x2': 0})

# Index the dataset directly; this dispatches to TriangleSCM.__getitem__.
batch = dataset_lin[0]

# Optional toggle: concatenate `batch.x_i` (presumably the intervention
# features — confirm) onto the node features along dim 1.
# batch.x_i = torch.cat([batch.x, batch.x_i], dim=1)

print(batch.x)
print(batch.edge_index)



# %% TEST DROPOUT IN ADJACENCY MATRIX
# Exercise the edge-dropout helpers on the sample graph loaded in the previous
# cell. Requires `batch` from that cell to have been run first.

# NOTE(review): prefer an explicit import
# (`from utils.dropout import dropout_adj, dropout_adj_parents`). The wildcard is
# kept here in case other parts of the file rely on further names it binds.
from utils.dropout import *

print(batch.edge_index)
print(batch.edge_attr)

# Random edge dropout with p=0.6. With keep_self_loops=False, self-loops
# presumably are eligible for dropout as well — confirm in utils.dropout.
edge_index, edge_attr = dropout_adj(
    batch.edge_index, batch.edge_attr, p=0.6, keep_self_loops=False
)
print(edge_index)
print(edge_attr)

# Drop all parent edges (p=1.0). prob_keep_self=0.0 presumably means no
# self-loop is kept either — TODO confirm the semantics in utils.dropout.
edge_index, edge_attr = dropout_adj_parents(
    batch.edge_index, batch.edge_attr, p=1.0, prob_keep_self=0.0
)
print(edge_index)
print(edge_attr)

# Drop all parent edges (p=1.0) but keep the self-loops: judging by its name,
# prob_keep_self=1.0 keeps every self-loop. NOTE(review): confirm against
# utils.dropout; this case does not drop self-loops.
edge_index, edge_attr = dropout_adj_parents(
    batch.edge_index, batch.edge_attr, p=1.0, prob_keep_self=1.0
)
print(edge_index)
print(edge_attr)