# %%
from data_modules.toy_scm import ToySCMDataModule

# Configuration for the toy structural-causal-model data module.
# NOTE: batch_size is small here; the original author noted 5000 is the
# value to use when training a classifier from the loader.
_data_module_kwargs = {
    "data_dir": "../Data",
    "dataset_name": 'law',
    "num_workers": 0,
    "num_samples_tr": 5000,
    "batch_size": 5,
    "normalize": 'std',
    "one_hot": False,
    "equations_type": 'linear',
}
data_module = ToySCMDataModule(**_data_module_kwargs)

# prepare_data() must run before any graph/dataset accessors below.
data_module.prepare_data()

# In-degree per node of the causal graph, cast to float.
deg = data_module.get_deg(indegree=True).float()
print(deg)
print(data_module.edge_dimension)

data_loader = data_module.train_dataloader()
scaler = data_module.scaler

# Raw training arrays (the original script labels these as unnormalized):
# observed features X, exogenous noise U, and flattened targets Y.
_train_dataset = data_module.train_dataset
X_full = _train_dataset.X
U = _train_dataset.U
Y = _train_dataset.Y.ravel()
#%%
attributes_dict = data_module.attributes_dict
print(attributes_dict)

# %%
# Column selections defining the feature set of each classifier:
#   fair    -> only the columns listed under 'fair_attributes'
#   unaware -> 'fair_attributes' columns followed by 'unfair_attributes'
#              columns (i.e. everything the unaware model is allowed to see)
mask_fair = attributes_dict['fair_attributes']
mask_unaware = mask_fair + attributes_dict['unfair_attributes']

X_fair = X_full[:, mask_fair]
X_unaware = X_full[:, mask_unaware]


#%%
from sklearn.linear_model import LogisticRegression

# X_full comes straight from train_dataset.X, which the loading cell marks as
# unnormalized. The lbfgs solver's default max_iter=100 frequently stops
# before convergence on unscaled features (emitting ConvergenceWarning and
# returning an under-fit model), so allow considerably more iterations.
_LR_MAX_ITER = 1000


def _fit_logistic(X, y):
    """Fit a LogisticRegression (default regularisation) on (X, y) and return it.

    Only ``max_iter`` deviates from sklearn's defaults, to give the solver
    room to converge on raw-scale inputs.
    """
    model = LogisticRegression(max_iter=_LR_MAX_ITER)
    model.fit(X, y)
    return model


# One classifier per feature view: all columns, unaware, and fair-only.
lr_full = _fit_logistic(X_full, Y)
lr_unaware = _fit_logistic(X_unaware, Y)
lr_fair = _fit_logistic(X_fair, Y)


#%%
# Counterfactual feature matrices for every training row under the
# interventions x1 := 0 and x1 := 1, produced by the dataset's
# get_counterfactual from the factual features X_full and noise U.
_get_counterfactual = data_loader.dataset.get_counterfactual
x_cf_real_x1_0 = _get_counterfactual(x_factual=X_full, u_factual=U, x_I={'x1': 0})
x_cf_real_x1_1 = _get_counterfactual(x_factual=X_full, u_factual=U, x_I={'x1': 1})
#%%
def _predict_all_views(x_cf):
    """Return (full, unaware, fair) model predictions for one counterfactual matrix.

    Each model is fed the same column subset it was trained on.
    """
    pred_full = lr_full.predict(x_cf)
    pred_unaware = lr_unaware.predict(x_cf[:, mask_unaware])
    pred_fair = lr_fair.predict(x_cf[:, mask_fair])
    return pred_full, pred_unaware, pred_fair


# Predictions in the world where x1 was set to 0 ...
y_cf_full_x1_0, y_cf_unaware_x1_0, y_cf_fair_x1_0 = _predict_all_views(x_cf_real_x1_0)
# ... and in the world where x1 was set to 1.
y_cf_full_x1_1, y_cf_unaware_x1_1, y_cf_fair_x1_1 = _predict_all_views(x_cf_real_x1_1)

#%%
from sklearn.metrics import mean_squared_error
import numpy as np


def _cf_rmse(pred_x1_0, pred_x1_1):
    """Root-mean-squared difference between predictions under x1:=0 and x1:=1.

    A value of 0 means the model's predictions never changed across the two
    counterfactual worlds.
    """
    return np.sqrt(mean_squared_error(pred_x1_0, pred_x1_1))


cf_unfair_full_x1 = _cf_rmse(y_cf_full_x1_0, y_cf_full_x1_1)
cf_unfair_unaware_x1 = _cf_rmse(y_cf_unaware_x1_0, y_cf_unaware_x1_1)
cf_unfair_fair_x1 = _cf_rmse(y_cf_fair_x1_0, y_cf_fair_x1_1)

# One value per line: full, unaware, fair.
print(cf_unfair_full_x1, cf_unfair_unaware_x1, cf_unfair_fair_x1, sep="\n")