# %%
import torch

from data_modules.toy_scm import ToySCMDataModule

# Build the toy SCM data module ('triangle' dataset, linear structural
# equations, batch_size=8; num_samples_tr is presumably the training-set size).
dm = ToySCMDataModule(dataset_name='triangle',
                      equations_type='linear',
                      batch_size=8,
                      num_samples_tr=5000)
dm.prepare_data()

# Kept only for interactive inspection; not used further in this script.
dataset = dm.train_dataset

# dataset_lin.set_intervention(x_I={'x2': 0})
# NOTE(review): `dataset_lin` is not defined anywhere above; `dataset` is
# presumably the intended target if this line is re-enabled.
# Grab a single training batch to drive the smoke tests below.
batch = next(iter(dm.train_dataloader()))

# batch.x_i = torch.cat([batch.x, batch.x_i], dim=1)

# Quick visual check of node features and graph connectivity.
print(batch.x)
print(batch.edge_index)

# NOTE(review): `data_loader` is never used later in this script.
data_loader = dm.test_dataloader()

# %% TEST VCAUSE MODULE
from models.vcause.vcause_module import VCAUSEModule

# Homogeneous VCAUSE model: a single feature size (x_dim) and a single
# likelihood ('normal') for every node. Sizes are read off the sampled batch
# and the data module, in the same order the original call evaluated them.
node_feat_dim = batch.x.shape[1]
edge_feat_dim = batch.edge_attr.shape[1]
node_count = dm.num_nodes

model = VCAUSEModule(
    x_dim=node_feat_dim,
    h_dim_list_dec=[16],
    h_dim_list_enc=[16],
    z_dim=4,
    m_layers=1,  # number of layers in the decoder's message MLP
    lambda_kld=1.0,
    deg=None,
    edge_dim=edge_feat_dim,
    num_nodes=node_count,
    residual=0,  # residual network in message passing (0 presumably disables it)
    drop_rate=0.0,
    drop_rate_i=0.0,
    act_name='relu',
    distr_x='normal',
    distr_z='normal',
    architecture='dgnn',
    K=1,
)

# %%

# Graph-structure inputs shared by the encoder and decoder calls below.
graph_inputs = dict(edge_index=batch.edge_index,
                    edge_attr=batch.edge_attr,
                    node_ids=batch.node_ids)

# Encoder: returns a distribution object (named as q(z|x)); draw one sample.
qz_x = model.encoder(X=batch.x, **graph_inputs)
Z = qz_x.sample()
print(Z.shape)
# %%

# Decoder: map the latent sample back through p(x|z) and draw a reconstruction.
px_z = model.decoder(Z=Z, **graph_inputs)
X_hat = px_z.sample()
print(X_hat.shape)

# %%
# Full forward pass with the 'elbo' estimator; returns (objective, info).
objective, info = model(data=batch, estimator='elbo', beta=1)
print(objective)

# %% TEST HETEROGENEOUS VCAUSE MODULE

from data_modules.real_scm import RealSCMDataModule

# Build the data module for the 'german' dataset via RealSCMDataModule:
# normalize='std' (presumably standardization), batch_size=32, and zero
# validation/test samples.
dm = RealSCMDataModule(dataset_name='german',
                       num_workers=0,
                       normalize='std',
                       batch_size=32,
                       num_samples_vl=0,
                       num_samples_ts=0,
                       equations_type='linear', )

dm.prepare_data()

# Kept only for interactive inspection; not used further in this script.
dataset = dm.train_dataset

# dataset_lin.set_intervention(x_I={'x2': 0})
# NOTE(review): `dataset_lin` is not defined anywhere above; `dataset` is
# presumably the intended target. Also confirm that 'x2' is a valid node name
# for the 'german' dataset before re-enabling.
# Grab a single training batch to drive the heterogeneous smoke tests below.
batch = next(iter(dm.train_dataloader()))

# batch.x_i = torch.cat([batch.x, batch.x_i], dim=1)

# Quick visual check of node features and graph connectivity.
print(batch.x)
print(batch.edge_index)

# NOTE(review): num_samples_ts=0 above, so this loader presumably covers an
# empty test split — confirm RealSCMDataModule supports that. `data_loader`
# is never used later in this script.
data_loader = dm.test_dataloader()
# %%

from models.vcause.hvcause_module import HVCAUSEModule

# Heterogeneous VCAUSE model: feature sizes and likelihoods are lists taken
# from the data module (presumably one entry per node) rather than a single
# x_dim / distr_x. Values are read in the same order the original call
# evaluated them.
feature_dims = dm.num_features_list
edge_feat_dim = batch.edge_attr.shape[1]
node_count = dm.num_nodes
likelihoods = dm.likelihood_list

model = HVCAUSEModule(
    x_dim_list=feature_dims,
    h_dim_list_dec=[16],
    h_dim_list_enc=[16],
    z_dim=4,
    m_layers=1,  # number of layers in the decoder's message MLP
    lambda_kld=1.0,
    deg=None,
    edge_dim=edge_feat_dim,
    num_nodes=node_count,
    residual=0,  # residual network in message passing (0 presumably disables it)
    drop_rate=0.0,
    drop_rate_i=0.0,
    act_name='relu',
    distr_x_list=likelihoods,
    distr_z='normal',
    architecture='dgnn',
    norm_categorical=False,
    K=1,
)

# %%

# Graph-structure inputs shared by the encoder and decoder calls below.
graph_inputs = dict(edge_index=batch.edge_index,
                    edge_attr=batch.edge_attr,
                    node_ids=batch.node_ids)

# Encoder: returns a distribution object (named as q(z|x)); draw one sample.
qz_x = model.encoder(X=batch.x, **graph_inputs)
Z = qz_x.sample()
print(Z.shape)
# %%

# Decoder: map the latent sample back through p(x|z) and draw a reconstruction.
px_z = model.decoder(Z=Z, **graph_inputs)
X_hat = px_z.sample()
print(X_hat.shape)

# %%
# Full forward pass with the 'elbo' estimator; returns (objective, info).
objective, info = model(data=batch, estimator='elbo', beta=1)
print(objective)
