# %%

# %%
from datasets.chain import ChainSCM
from datasets.utils import normalize_adj
import torch


dataset = ChainSCM()

# Prepare 1000 samples with self-loops enabled and no adjacency
# normalisation (normalize_A=None).
dataset.prepare_data(
    n_samples=1000,
    normalize_A=None,
    add_self_loop=True,
)

# Inspect the graph attached to the first sample.
data = dataset[0]
print(f"Edge index:\n{data.edge_index}")
print(f"Edge attr:\n{data.edge_attr}")

# Optional: inspect the raw / column-normalised SCM adjacency matrix.
# print(f"SCM_adj:\n{dataset.SCM_adj}")
# print(f"SCM_adj:\n{normalize_adj(dataset.SCM_adj, 'col')}")


# %%
from datasets.triangle import TriangleSCM
from datasets.transforms import ToTensor
from datasets.utils import normalize_adj
import torch


dataset = TriangleSCM(transform=ToTensor())

# Same preparation settings as the chain example: 1000 samples,
# self-loops on, no adjacency normalisation.
dataset.prepare_data(
    n_samples=1000,
    normalize_A=None,
    add_self_loop=True,
)

# Inspect the graph attached to the first sample.
data = dataset[0]
print(f"Edge index:\n{data.edge_index}")
print(f"Edge attr:\n{data.edge_attr}")


# %% Counterfactual

# Counterfactual of the first sample under do(x3 = 2). The factual features
# and ``u`` (presumably the exogenous noise — confirm) are passed as
# single-row tensors of shape (1, -1).
x_row = data.x.reshape(1, -1)
u_row = data.u.reshape(1, -1)
x_cf = dataset.get_counterfactual(
    x_factual=x_row,
    u_factual=u_row,
    x_I={'x3': 2},
)

# Factual row followed by its counterfactual, for side-by-side comparison.
print(data.x.reshape(1, -1))
print(x_cf)

# %% Set intervention data

# Start from a clean dataset, then apply do(x1 = 2).
dataset.clean_intervention()
dataset.set_intervention(x_I={'x1': 2})

# With an intervention active, a sample exposes both the factual features
# (x) and the intervened graph (x_i, edge_index_i, edge_attr_i).
data = dataset[0]
print(f"x:\n{data.x}")
print(f"x_i:\n{data.x_i}")
print(f"Edge index:\n{data.edge_index_i}")
print(f"Edge attr:\n{data.edge_attr_i}")
print(f"SCM_adj_i:\n{dataset.SCM_adj_i}")

# Monte-Carlo mean over axis 0 of 200k samples drawn under do(x1 = 2)
# (presumably rows are samples — confirm).
x_i = dataset.sample_intervention(x_I={'x1': 2}, n_samples=200000)
print(x_i.mean(0))

# %% Sample observational and intervention data

# NOTE(review): the do(x1 = 2) intervention set in the previous cell has not
# been cleaned at the first ``sample`` call below — confirm whether ``sample``
# is affected by it (the resample after ``clean_intervention`` suggests this
# is being checked on purpose).
# NOTE(review): only the ``sample`` output is transposed; presumably
# ``sample`` returns (n_vars, n_samples) while ``sample_intervention`` returns
# (n_samples, n_vars) — confirm.
N_SAMPLES = 200000

x_o = dataset.sample(n_samples=N_SAMPLES).T
x_i = dataset.sample_intervention(x_I={'x2': 30}, n_samples=N_SAMPLES)

print(f'Mean o: {x_o.mean(0)}')
print(f'Std o: {x_o.std(0)}')
print(f'Mean i: {x_i.mean(0)}')
print(f'Std i: {x_i.std(0)}')

# Resample observational data with every intervention removed.
dataset.clean_intervention()
x_o = dataset.sample(n_samples=N_SAMPLES).T
print(f'Mean o: {x_o.mean(0)}')
print(f'Std o: {x_o.std(0)}')
# %% Variational Graph Autoencoder for Features



from data_modules.toy_scm3 import ToySCM3DataModule

data_module = ToySCM3DataModule(
    data_dir="../Data",
    num_samples_tr=10000,
    batch_size=2,
    one_hot=False,
    equations_type='linear',
)

data_module.prepare_data()

# Degree statistics (fed to FVGAE as ``deg`` below; presumably what the
# PNA aggregator needs) and the edge-feature dimension.
deg = data_module.get_deg(indegree=True)
print(deg)
print(data_module.edge_dimension)

# Pull one training batch as a smoke test of the dataloader.
data = next(iter(data_module.train_dataloader()))
# %%

from models.vgae import FVGAE
import torch.nn as nn
import torch
# One-dimensional node features and latents with a single hidden layer of
# width 4; the graph layers use the PNA architecture with the degree
# statistics and edge dimension taken from the data module above.
model = FVGAE(
    x_dim=1,
    h_dim_list=[4],
    z_dim=1,
    deg=deg,
    edge_dim=data_module.edge_dimension,
    residual=0,
    drop_rate=0.1,
    act_name='relu',
    distr_x='delta',
    distr_z='normal',
    architecture='pna',
)


def init_weights(m, bound=1.0):
    """Re-initialise the weights of ``nn.Linear`` layers from U(-bound, bound).

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Only the ``weight`` of ``nn.Linear`` modules (including
    subclasses) is touched; biases and all other module types are left as-is.

    Args:
        m: The module currently being visited.
        bound: Half-width of the uniform interval. Defaults to ``1.0``, the
            range previously hard-coded here.
    """
    # isinstance rather than ``type(m) == nn.Linear`` so that subclasses of
    # nn.Linear are initialised as well.
    if isinstance(m, nn.Linear):
        # nn.init.uniform_ fills in place under no_grad, avoiding the
        # discouraged ``.data`` access.
        nn.init.uniform_(m.weight, -bound, bound)


# Visit every submodule and re-initialise the nn.Linear weights.
model.apply(init_weights)
# %%

# Switch to evaluation mode so dropout (drop_rate=0.1 above) is presumably
# inactive and repeated encoder passes on identical input are deterministic.
model.eval()

my_dataset = data_module.train_dataset
data = my_dataset[0]

# Encode a detached copy of the first sample's node features and keep the
# posterior mean as the reference latent code.
x = data.x.detach().clone()
qz_x = model.encoder(x, data.edge_index, data.edge_attr)
z = qz_x.mean
print(z)

# Reference reconstruction that the intervention checks compare against.
x_hat = model.decoder_module(z, data.edge_index, data.edge_attr)


# %% Check intervention on x1
# do(x1 = 10): expected to change the latent code and the reconstruction of
# every node (the pattern the asserts below encode).
#
# After ``set_intervention`` the intervened sample lives in the ``*_i`` fields
# (``x_i``, ``edge_index_i``, ``edge_attr_i``), as printed in the
# "Set intervention data" cell. ``x``/``edge_index``/``edge_attr`` still hold
# the factual graph, so encoding them would leave the encoder input unchanged.
# NOTE(review): assumes ``data_module.train_dataset`` exposes the same ``*_i``
# fields as TriangleSCM — confirm.

my_dataset.set_intervention(x_I={'x1': 10})
try:
    data = my_dataset[0]
    qz_x = model.encoder(data.x_i, data.edge_index_i, data.edge_attr_i)
    z_i = qz_x.mean
finally:
    # Always restore the observational dataset, even if encoding fails.
    my_dataset.clean_intervention()

# torch.equal compares whole per-node vectors, so the check also works for
# z_dim > 1 (a bare ``a != b`` on a multi-element tensor is ambiguous).
for node, should_change in enumerate((True, True, True)):
    changed = not torch.equal(z[node], z_i[node])
    assert changed == should_change, f"z[{node}]: expected changed={should_change}"

x_hat_i = model.decoder_module(z_i, data.edge_index_i, data.edge_attr_i)
for node, should_change in enumerate((True, True, True)):
    changed = not torch.equal(x_hat[node], x_hat_i[node])
    assert changed == should_change, f"x_hat[{node}]: expected changed={should_change}"
# %% Check intervention on x2
# do(x2 = 10): expected to leave node x1 unchanged and to change x2 and x3,
# both in the latent code and in the reconstruction.
#
# The intervened sample is read from the ``*_i`` fields (``x_i``,
# ``edge_index_i``, ``edge_attr_i``) set up by ``set_intervention``; the plain
# ``x``/``edge_index``/``edge_attr`` fields still hold the factual graph.
# NOTE(review): assumes ``data_module.train_dataset`` exposes the same ``*_i``
# fields as TriangleSCM — confirm.

my_dataset.set_intervention(x_I={'x2': 10})
try:
    data = my_dataset[0]
    qz_x = model.encoder(data.x_i, data.edge_index_i, data.edge_attr_i)
    z_i = qz_x.mean
finally:
    # Always restore the observational dataset, even if encoding fails.
    my_dataset.clean_intervention()

# Exact equality is intended here: unchanged nodes go through the identical
# deterministic computation (model is in eval mode).
for node, should_change in enumerate((False, True, True)):
    changed = not torch.equal(z[node], z_i[node])
    assert changed == should_change, f"z[{node}]: expected changed={should_change}"

x_hat_i = model.decoder_module(z_i, data.edge_index_i, data.edge_attr_i)
for node, should_change in enumerate((False, True, True)):
    changed = not torch.equal(x_hat[node], x_hat_i[node])
    assert changed == should_change, f"x_hat[{node}]: expected changed={should_change}"



# %% Check intervention on x3
# do(x3 = 10): expected to leave nodes x1 and x2 unchanged and to change only
# x3, both in the latent code and in the reconstruction.
#
# The intervened sample is read from the ``*_i`` fields (``x_i``,
# ``edge_index_i``, ``edge_attr_i``) set up by ``set_intervention``; the plain
# ``x``/``edge_index``/``edge_attr`` fields still hold the factual graph.
# NOTE(review): assumes ``data_module.train_dataset`` exposes the same ``*_i``
# fields as TriangleSCM — confirm.

my_dataset.set_intervention(x_I={'x3': 10})
try:
    data = my_dataset[0]
    qz_x = model.encoder(data.x_i, data.edge_index_i, data.edge_attr_i)
    z_i = qz_x.mean
finally:
    # Always restore the observational dataset, even if encoding fails.
    my_dataset.clean_intervention()

# Exact equality is intended here: unchanged nodes go through the identical
# deterministic computation (model is in eval mode).
for node, should_change in enumerate((False, False, True)):
    changed = not torch.equal(z[node], z_i[node])
    assert changed == should_change, f"z[{node}]: expected changed={should_change}"

x_hat_i = model.decoder_module(z_i, data.edge_index_i, data.edge_attr_i)
for node, should_change in enumerate((False, False, True)):
    changed = not torch.equal(x_hat[node], x_hat_i[node])
    assert changed == should_change, f"x_hat[{node}]: expected changed={should_change}"

