# %%
from data_modules.toy_scm import ToySCMDataModule

# Toy SCM data module for the 'triangle' dataset, using linear structural
# equations, normalize='std' and no one-hot encoding.
data_module = ToySCMDataModule(
    data_dir="../Data",
    dataset_name='triangle',
    num_workers=0,
    num_samples_tr=5000,
    batch_size=5,
    normalize='std',
    one_hot=False,
    equations_type='linear',
)

data_module.prepare_data()

# Per-node in-degree (cast to float) and the edge feature dimension.
deg = data_module.get_deg(indegree=True).float()
print(deg)
print(data_module.edge_dimension)

# Training loader, the module's scaler, and one batch to play with below.
data_loader = data_module.train_dataloader()
scaler = data_module.scaler
batch = next(iter(data_loader))

# %%

# One row per graph in the batch.
x_factual = batch.x.view(batch.num_graphs, -1)

# Ground-truth counterfactual with node 'x2' set to 2
# (x_I is presumably the intervention mapping — confirm in get_counterfactual).
x_cf_real = data_loader.dataset.get_counterfactual(
    x_factual=x_factual,
    u_factual=batch.u.view(batch.num_graphs, -1),
    x_I={'x2': 2},
)

# Compare both tensors in the normalized space...
print('x_factual_norm')
print(x_factual)
print('x_cf_real_norm')
print(scaler.transform(x_cf_real))

# ...and in the original (un-normalized) space.
print('x_factual')
print(scaler.inverse_transform(x_factual))
print('x_cf_real')
print(x_cf_real)

# %%

from models.vgae.fvgae.cvae_multi_3 import MCVAE

# Node order and, for each position, the indices of that node's parents:
# node 0 has none, node 1 has [0], node 2 has [0, 1].
topological_nodes = [0, 1, 2]
topological_parents = [[], [0], [0, 1]]

model = MCVAE(
    x_dim=1,
    h_dim_list_dec=[32],
    h_dim_list_enc=[32],
    z_dim=4,
    lambda_kld=1.0,
    drop_rate=0.1,
    act_name='relu',
    distr_x='delta',
    distr_z='normal',
    scaler=0,  # NOTE(review): placeholder 0 rather than data_module.scaler — confirm intent
    topological_nodes=topological_nodes,
    topological_parents=topological_parents,
)



# %%
# Evaluate the ELBO objective of one node's conditional VAE on the batch
# drawn in the first cell. The batch is bound to `batch`; there is no `data`
# variable in this script, so referencing it raised NameError.
X = batch.x.view(batch.num_graphs, -1)  # one row per graph

i = 2
node = topological_nodes[i]
parents = topological_parents[i]

x_i = X[:, node].unsqueeze(1)  # keep a trailing column dim: (num_graphs, 1)
# A node without parents has no conditioning data.
pa_i = X[:, parents] if parents else None

objective, _ = model.cvae_list[i](x_i, estimator='elbo', cond_data=pa_i)
print(objective)

#%%
# Query the per-node training schedule. The arguments are presumably
# (node index, epoch); confirm against MCVAE.is_training_node_i.
# Configure and call the `model` instance rather than the class:
# - assigning on the class would change num_epochs_per_nodes for every MCVAE;
# - passing the class as `self` fails if the method reads attributes set in __init__.
model.num_epochs_per_nodes = 10
print(model.is_training_node_i(1, 19))