# %%
import numpy as np

from aif360.datasets import GermanDataset

# %%

# Load the German credit data with 'sex' as the only protected attribute.
dataset = GermanDataset(protected_attribute_names=['sex'])

# Recode the labels as 0/1. In the raw data 2 = bad credit and 1 = good
# credit, so entries equal to 2 become 0 and every other entry becomes 1.
bad_credit_mask = dataset.labels == 2
dataset.labels = np.where(bad_credit_mask, 0, 1)
# The unfavorable label has to follow the recoding (2 -> 0).
dataset.unfavorable_label = 0.0

# Bare expression: only displays something when run interactively.
dataset.metadata['protected_attribute_maps']

# Keep only the first element returned by convert_to_dataframe().
df = dataset.convert_to_dataframe()[0]

# %%
# list(df.columns)
# Columns kept for the model. The letter tags (A, C, R, S) are carried over
# from the original per-column annotations.
columns_interest = [
    # A
    'sex',
    # C
    'age',
    # R
    'credit_amount',
    'month',  # repayment duration
    # S: housing
    'housing=A151',
    'housing=A152',
    'housing=A153',
    # S: savings
    'savings=A61',
    'savings=A62',
    'savings=A63',
    'savings=A64',
    'savings=A65',
    # S: status of checking account
    'status=A11',
    'status=A12',
    'status=A13',
    'status=A14',
]

# Feature table restricted to the selected columns.
X = df[columns_interest]
X.head()
list(df.columns)

# Target column; value_counts() shows the class balance when run interactively.
Y = df['credit']
Y.value_counts()

# %%

from datasets.german import GermanSCM
from datasets.transforms import ToTensor

# Wrap the selected features and target in the project's SCM dataset.
# NOTE: this rebinds `dataset`, replacing the aif360 GermanDataset loaded above.
dataset = GermanSCM(X=X, Y=Y, transform=ToTensor())

dataset.prepare_data()

# Smoke test: fetch the first sample. Use the subscript operator rather than
# calling the __getitem__ dunder method directly.
data = dataset[0]

# %%

from data_modules.real_scm import RealSCMDataModule

# Data module for the 'german' dataset, configured with batches of two
# samples and no worker processes.
data_module = RealSCMDataModule(
    data_dir="../Data",
    dataset_name='german',
    num_workers=0,
    batch_size=2,
    normalize='std',
    one_hot=False,
)
data_module.prepare_data()

# Take the first batch from the validation loader; the model checks in the
# following cells run on this single batch.
valid_loader = data_module.val_dataloader()
data = next(iter(valid_loader))


# %%

from models.vgae.fvgae.hfvgae_module import HFVGAEModule

# Build the HFVGAE module; every data-dependent size and likelihood is taken
# from the data module.
model = HFVGAEModule(
    # Per-node feature dimensions, e.g. [[2], [3,4], [3,4,5]] for
    # [node_1, node_2, node_3]; must follow the node order of the dataset.
    x_dim_list=data_module.num_features_list,
    h_dim_list_dec=[16],
    h_dim_list_enc=[16],
    z_dim=4,
    # Number of layers for the message MLP of the decoder.
    m_layers=1,
    lambda_kld=0.05,
    deg=None,
    edge_dim=data_module.edge_dimension,
    edge_dim_ancestors=None,
    num_nodes=data_module.num_nodes,
    # Use residual network in message passing.
    residual=0,
    # Use update network in message passing.
    update=0,
    drop_rate=0.0,
    drop_rate_i=0.0,
    act_name='relu',
    # Should be the same size as x_dim_list.
    distr_x_list=data_module.likelihood_list,
    distr_z='normal',
    architecture='mygcnv2',
    K=3,
)


# %%


# Encode the batch. With return_mean=False the encoder result is used as a
# distribution-like object, from which a latent sample is drawn below.
qz_x = model.encoder(
    X=data.x,
    edge_index=data.edge_index,
    edge_attr=data.edge_attr,
    return_mean=False,
    node_ids=data.node_ids,
)

# Draw one latent sample from the encoder output.
z = qz_x.sample()
# %%

# Decode the latent sample over the same graph structure as the batch.
px_z = model.decoder(
    z,
    data.edge_index,
    edge_attr=data.edge_attr,
    node_ids=data.node_ids,
)

# Sample a reconstruction, then score the batch's own features (as returned
# by get_x_graph) under the decoder output.
x_hat = px_z.sample()
log_prob = px_z.log_prob(model.get_x_graph(data))



# %%

from models.vgae.fvgae.fvgae import FVGAE

# Build the full FVGAE model. Note that this rebinds `model`, replacing any
# model created earlier in the script.
model = FVGAE(
    # --- network sizes -----------------------------------------------------
    x_dim=data_module.num_features_list,
    h_dim_list_dec=[16],
    h_dim_list_enc=[16],
    z_dim=1,
    m_layers=1,
    deg=None,  # only used by the PNA architecture
    edge_dim=data_module.edge_dimension,
    edge_dim_ancestors=None,
    num_nodes=data_module.num_nodes,
    # --- objective weights -------------------------------------------------
    beta=1.0,
    lambda_kld=1.0,
    annealing_beta=False,
    residual=0,  # only used by the PNA architecture
    # --- dropout settings (all disabled) -----------------------------------
    drop_rate=0.0,
    drop_rate_i=0.0,
    dropout_adj_rate=0.0,
    dropout_adj_I_rate=0.0,
    dropout_adj_I_prob_keep_self=0.0,
    keep_self_loops=True,
    dropout_input_rate=0.0,
    dropout_adj_T=0,
    # --- activations, distributions, architecture --------------------------
    act_name='relu',
    distr_x=data_module.likelihood_list,
    distr_z='normal',
    architecture='mygcnv2',  # alternatives: PNA, MyGCN
    estimator='iwae',
    K=3,  # only used by the IWAE estimator
    # --- interventions and data handling -----------------------------------
    intervention_list_in=None,  # in-distribution interventions
    intervention_list_out=None,  # out-of-distribution interventions
    scaler=data_module.scaler,
    init=None,
    # NOTE(review): data_is_toy=True although dataset_name is 'german' -- confirm.
    data_is_toy=True,
    is_heterogeneous=True,
)


# %%


# Run a single validation step on the sampled batch.
# NOTE(review): the second positional argument is presumably the batch index --
# confirm against FVGAE.validation_step.
out = model.validation_step(data, 1)


# %%

# Compute the model's objective metrics over the validation loader, passing
# 'valid' as the name argument.
output = model.get_objective_metrics(data_loader=valid_loader, name='valid')