# %%
import torch

from data_modules.toy_scm import ToySCMDataModule

# Configure the toy-SCM data module ('triangle' dataset, linear equations).
dm_config = {
    'dataset_name': 'triangle',
    'equations_type': 'linear',
    'batch_size': 8,
    'num_samples_tr': 5000,
}
dm = ToySCMDataModule(**dm_config)
dm.prepare_data()  # called before train_dataset / train_dataloader are read

dataset = dm.train_dataset

# Optional intervention, left disabled.
# NOTE(review): `dataset_lin` is not defined in this script -- presumably
# `dataset` was meant; confirm before re-enabling.
# dataset_lin.set_intervention(x_I={'x2': 0})

# Pull one training batch; the cells below reuse it as model input.
batch = next(iter(dm.train_dataloader()))

# batch.x_i = torch.cat([batch.x, batch.x_i], dim=1)

# Eyeball the node features and the connectivity of the sampled batch.
print(batch.x)
print(batch.edge_index)

# %% TEST  BLOCKS: PNA

from modules.blocks.pna import PNAConv

# Smoke test: build a single PNAConv block and run one forward pass on the
# batch sampled above, printing only the output shape.
aggregators = ['sum', 'min', 'max', 'std']
scalers = ['identity', 'amplification', 'attenuation']

# Input width and edge-feature width are read from the batch so the block
# always matches the data module's output.
pna_conv_kwargs = dict(
    in_channels=batch.x.shape[1],
    out_channels=16,
    aggregators=aggregators,
    scalers=scalers,
    deg=dm.get_deg(indegree=True),
    edge_dim=batch.edge_attr.shape[1],
    towers=1,
    pre_layers=1,
    post_layers=1,
    divide_input=False,
)
model = PNAConv(**pna_conv_kwargs)

forward_inputs = dict(
    x=batch.x,
    edge_index=batch.edge_index,
    edge_attr=batch.edge_attr,
)
out = model(**forward_inputs)
print(out.shape)
# %% TEST  BLOCKS: DISJOINT GRAPH CONVOLUTION
from modules.blocks.disjoint_graph_conv import DisjointGConv

# Smoke test: build a DisjointGConv block and run one forward pass on the
# sampled batch, printing only the output shape.
#
# This block is configured with a single `aggr` string, so the PNA-style
# `aggregators` / `scalers` lists are not defined in this cell: they were
# never passed to DisjointGConv.
# m_channels starts with the node-feature width of the batch
# (presumably followed by hidden/output widths -- TODO confirm).
model = DisjointGConv(m_channels=[batch.x.shape[1], 16, 8],
                      edge_dim=batch.edge_attr.shape[1],
                      aggr='add',
                      act_name='relu',
                      drop_rate=0.0,
                      drop_rate_i=0.0,
                      use_i_in_message_ij=True)

out = model(x=batch.x,
            edge_index=batch.edge_index,
            edge_attr=batch.edge_attr)
print(out.shape)

# %% TEST  BLOCKS: DISJOINT PNA

from modules.blocks.pna.disjoint_pna import DisjointPNAConv

# Smoke test: build a DisjointPNAConv block and run one forward pass on the
# sampled batch, printing only the output shape.
aggregators = ['sum', 'min', 'max', 'std']
scalers = ['identity', 'amplification', 'attenuation']

# Values derived from the batch / data module, read in the same order as
# they are consumed by the constructor.
node_feat_dim = batch.x.shape[1]
deg_in = dm.get_deg(indegree=True)
edge_feat_dim = batch.edge_attr.shape[1]

model = DisjointPNAConv(
    m_channels=[node_feat_dim, 16, 8],
    aggregators=aggregators,
    scalers=scalers,
    deg=deg_in,
    edge_dim=edge_feat_dim,
    num_nodes=dm.num_nodes,
    act_name='relu',
    drop_rate=0.0,
)

# Unlike the plain PNAConv call, the forward pass also receives node_ids.
out = model(
    x=batch.x,
    edge_index=batch.edge_index,
    edge_attr=batch.edge_attr,
    node_ids=batch.node_ids,
)
print(out.shape)


# %% TEST  MODULES: PNA

# --------------------------------------------------------------------#
#                            MODULES                                  #
# --------------------------------------------------------------------#
from modules.pna import PNAModule

# Smoke test: build the multi-layer PNAModule and run one forward pass on
# the sampled batch, printing only the output shape.
# aggregators / scalers are passed as None here (presumably selecting the
# module's defaults -- TODO confirm against PNAModule).
pna_module_kwargs = {
    'c_list': [batch.x.shape[1], 16, 16, 8],
    'deg': dm.get_deg(indegree=True),
    'edge_dim': batch.edge_attr.shape[1],
    'drop_rate': 0.1,
    'act_name': 'relu',
    'aggregators': None,
    'scalers': None,
    'residual': False,
}
model = PNAModule(**pna_module_kwargs)

out = model(
    x=batch.x,
    edge_index=batch.edge_index,
    edge_attr=batch.edge_attr,
)
print(out.shape)

# %% TEST  MODULES: DISJOINT GNN

from modules.disjoint_gnn import DisjointGNN

# Smoke test: build the DisjointGNN module and run one forward pass on the
# sampled batch, printing only the output shape.
gnn_kwargs = dict(
    c_list=[batch.x.shape[1], 16, 16, 8],
    m_layers=1,
    edge_dim=batch.edge_attr.shape[1],
    num_nodes=dm.num_nodes,
    drop_rate=0.1,
    drop_rate_i=0.0,
    residual=0,
    act_name='relu',
    use_i_in_message_ij=True,
    aggr='add',
)
model = DisjointGNN(**gnn_kwargs)

# The first three inputs are passed positionally; node_ids goes by keyword.
out = model(
    batch.x,
    batch.edge_index,
    batch.edge_attr,
    node_ids=batch.node_ids,
)
print(out.shape)

# %% TEST  MODULES: DISJOINT PNA

from modules.disjoint_pna import DisjointPNA

# Smoke test: build the DisjointPNA module and run one forward pass on the
# sampled batch, printing only the output shape.
# aggregators / scalers are passed as None here (presumably selecting the
# module's defaults -- TODO confirm against DisjointPNA).
layer_widths = [batch.x.shape[1], 16, 16, 8]
edge_feat_dim = batch.edge_attr.shape[1]

model = DisjointPNA(
    c_list=layer_widths,
    m_layers=1,
    edge_dim=edge_feat_dim,
    deg=dm.get_deg(indegree=True),
    num_nodes=dm.num_nodes,
    aggregators=None,
    scalers=None,
    drop_rate=0.1,
    residual=0,
    act_name='relu',
)

# The first three inputs are passed positionally; node_ids goes by keyword.
out = model(
    batch.x,
    batch.edge_index,
    batch.edge_attr,
    node_ids=batch.node_ids,
)
print(out.shape)