# %%
# Cell 1: build the toy SCM data module (SCM presumably = structural causal
# model) for the 'triangle' dataset and pull one training batch that the later
# cells use interactively.
from data_modules.toy_scm import ToySCMDataModule

# Arguments are forwarded to ToySCMDataModule, whose implementation is not
# visible here. Presumably:
#   - num_samples_tr is the number of training samples;
#   - normalize='std' selects standardization;
#   - num_workers=0 loads data in the main process.
# TODO confirm these against ToySCMDataModule.
data_module = ToySCMDataModule(data_dir="../Data",
                               dataset_name='triangle',
                               num_workers=0,
                               num_samples_tr=5000,
                               batch_size=500,
                               normalize='std',
                               one_hot=False,
                               equations_type='linear')

# Presumably generates/loads the dataset under data_dir; must run before the
# dataloaders are requested below.
data_module.prepare_data()
# In-degree statistics from the data module, converted with .float() and
# printed for inspection.
# NOTE(review): whether this is a per-node degree vector or a degree histogram
# depends on get_deg — confirm.
deg = data_module.get_deg(indegree=True).float()
print(deg)
print(data_module.edge_dimension)

# The *training* loader is used by every later cell, including those that
# label their results 'test'.
data_loader = data_module.train_dataloader()
# Kept only for interactive inspection: later cells pass data_module.scaler to
# CAREFL directly and never read this name.
scaler = data_module.scaler
# First batch of the training loader; the forward-pass cell reshapes batch.x
# per graph.
batch = next(iter(data_loader))

# %%
# Cell 2: construct a CAREFL flow model for the dataset loaded in cell 1.
from models.carefl.carefl import CAREFL

# NOTE(review): num_nodes=3 is hard-coded. It presumably must equal the number
# of nodes in the 'triangle' graph chosen above; consider deriving it from
# data_module so the two cannot drift apart.
# distr_z presumably selects the base/noise distribution of the flow — confirm.
# n_layers=1 / n_hidden=1 make a very small flow; the exact meaning of
# n_hidden is defined by CAREFL — TODO confirm.
# NOTE(review): get_intervention_list(True) passes an unnamed boolean. Its
# meaning is defined by the dataset class; prefer passing it by keyword once
# the parameter name is confirmed.
# The scaler comes from the data module so the model can presumably map
# between normalized and original data scales — confirm.
model = CAREFL(num_nodes=3,
               distr_z='laplace',
               flow_net_class='mlp',
               flow_architecture='spline',
               n_layers=1,
               n_hidden=1,
               parity=False,
               intervention_list=data_module.train_dataset.get_intervention_list(True),
               scaler=data_module.scaler,
               init=None)

# %%
# Cell 3: run one forward pass of the flow on the batch from cell 1, then
# compute the objective metrics over the whole loader.
# view(...) reshapes batch.x to one row per graph in the batch. This assumes
# batch.x is a tensor whose element count is divisible by batch.num_graphs
# (i.e. every graph contributes the same number of values) — confirm.
X = batch.x.view(batch.num_graphs, -1)
# The first output of flow_model is discarded (presumably the latent z).
# prior_logprob and log_det are kept only for interactive inspection.
_, prior_logprob, log_det = model.flow_model(X)

# NOTE(review): data_loader is the *training* loader (cell 1), yet the metrics
# are labelled name='test' — confirm this label is intended.
output = model.get_objective_metrics(data_loader=data_loader, name='test')

# %%
# Cell 4: sample from the model's generative distribution for the loader's data.
# Names suggest z = latent samples, x = generated samples and x_real = data
# drawn from the loader — confirm against CAREFL.get_x_gener_distribution.
# These names are overwritten by the next cell.

z, x, x_real = model.get_x_gener_distribution(data_loader)

# %%
# Cell 5: reconstructed observational distribution. This overwrites the
# z, x, x_real values produced by the previous cell.
z, x, x_real = model.get_recons_obs_distribution(data_loader)

# %%
# Cell 6: interventional distribution for x_I={'x1': 2} (presumably do(x1 = 2)).
# use_aggregated_posterior=False is forwarded to CAREFL; its exact effect is
# defined there — TODO confirm.
# Returns generated and real outputs as two separate objects (names suggest
# dicts).
x_gener_dict_out, x_real_dict_out = model.get_x_gener_I_distribution(data_loader,
                                                                     x_I={'x1': 2},
                                                                     use_aggregated_posterior=False)
# %%
# Cell 7: counterfactual distribution for the same intervention x_I={'x1': 2}.
# is_noise=False presumably means x_I holds observed values rather than noise
# values — confirm against CAREFL.get_x_cf_distribution.
x_gener_dict_cf, x_real_dict = model.get_x_cf_distribution(data_loader, x_I={'x1': 2}, is_noise=False)

# %%
# Cell 8: run the model's test routine over the loader.
# NOTE(review): as in cell 3, this evaluates the *training* loader under
# name='test'. save_dir='.' presumably makes any artifacts land in the current
# working directory — confirm both are intended.

output = model.my_test_dataloader(data_loader, name='test', save_dir='.')
