from color_mnist_generator import color_mnist_generator
from utils import dataset_augmentation
from models import CSAE, CVAE, oracle
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
# Generate the colour-MNIST splits used throughout this script.
# NOTE(review): `y_test` sits between the training images and training labels
# in this unpacking, yet get_metrics pairs it element-wise with the *test*
# labels (it must have len(num_labels_test) entries) — confirm the return
# order of color_mnist_generator.
(
    img_train,
    y_test,
    num_labels_train,
    hue_labels_train,
    img_test,
    num_labels_test,
    hue_labels_test,
    img_test_cf,
    hue_labels_test_cf,
) = color_mnist_generator()
def _scaled_mae(a, b):
    """Return 255 * mean(|a - b|).

    NOTE(review): the 255 factor presumably maps images stored in [0, 1] to
    pixel-intensity units — confirm against color_mnist_generator.
    """
    return 255 * np.mean(np.abs(a - b))


def _categorical_accuracy(y_true, y_pred):
    """Return Keras categorical accuracy of *y_pred* against one-hot *y_true*."""
    metric = tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy')
    metric.update_state(y_true, y_pred)
    return metric.result().numpy()


def _sample_different_digits(true_digits, seed=2):
    """Return, for each entry of *true_digits*, a random digit in 0-9 that differs from it.

    Reseeds NumPy's global RNG with *seed* and draws one candidate at a time
    (rejection sampling), so the sequence is reproducible run to run.
    """
    np.random.seed(seed)
    sampled = []
    for digit in true_digits:
        while True:
            candidate = np.random.randint(10)
            if candidate != digit:
                sampled.append(candidate)
                break
    return sampled


def _sample_hues(count, seed=2):
    """Reseed NumPy's global RNG with *seed* and return *count* uniform [0, 1) draws as a column."""
    np.random.seed(seed)
    return np.random.random(count).reshape(-1, 1)


def get_metrics(model):
    """Evaluate a conditional generative model on the colour-MNIST test split.

    Prints each metric as it is computed and returns all of them.

    Reads the module-level test data (``img_test``, ``num_labels_test``,
    ``hue_labels_test``, ``img_test_cf``, ``hue_labels_test_cf``, ``y_test``)
    and the trained oracles ``oracle_number`` / ``oracle_hue``, so it must be
    called after those are defined.

    Side effects: sets ``model.training = False`` (not restored) and reseeds
    NumPy's global RNG with 2.

    Args:
        model: object exposing ``composition``, ``reversibility`` and
            ``cf_generation`` with the call signatures used below.

    Returns:
        tuple: ``(l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1,
        num_eff2, hue_eff2, l1_cf)``.
    """
    model.training = False
    # Factual conditioning: one-hot digit labels followed by hue labels.
    factual = np.concatenate((num_labels_test, hue_labels_test), axis=1)

    # Composition (n=1): MAE between model.composition output and the inputs.
    comp_img = model.composition(factual, img_test, 1)
    l1_comp = _scaled_mae(comp_img, img_test)
    print("composition MAE: ", l1_comp)

    # Number reversibility: counterfactual digits, each different from y_test.
    # NOTE(review): assumes y_test holds the integer digit labels of the test
    # split, aligned with num_labels_test — confirm.
    label_num_cf = to_categorical(_sample_different_digits(y_test), num_classes=10)
    digit_cf = np.concatenate((label_num_cf, hue_labels_test), axis=1)
    rev_img = model.reversibility(factual, digit_cf, img_test, 1)
    l1_num_rev = _scaled_mae(rev_img, img_test)
    print("number reversibility: ", l1_num_rev)

    # Hue reversibility: uniform random counterfactual hues, digits unchanged.
    hue_cf = _sample_hues(len(img_test))
    hue_cf_cond = np.concatenate((num_labels_test, hue_cf), axis=1)
    rev_img = model.reversibility(factual, hue_cf_cond, img_test, 1)
    l1_hue_rev = _scaled_mae(rev_img, img_test)
    print("hue reversibility: ", l1_hue_rev)

    # Digit intervention: the number oracle should recognise the new digit,
    # and the hue oracle should still see the factual hue.
    print("digit intervention")
    img_cf = model.cf_generation(factual, digit_cf, img_test)
    num_eff1 = _categorical_accuracy(label_num_cf, oracle_number(img_cf))
    print("number effectiveness: ", num_eff1)
    hue_eff1 = np.mean(np.abs(oracle_hue(img_cf) - hue_labels_test))
    print("hue effectiveness: ", hue_eff1)

    # Hue intervention: the hue oracle should see the new hue, and the number
    # oracle should still see the factual digit. Reseeding reproduces exactly
    # the hues used for hue reversibility above.
    print("hue intervention")
    hue_cf = _sample_hues(len(img_test))
    hue_cf_cond = np.concatenate((num_labels_test, hue_cf), axis=1)
    img_cf = model.cf_generation(factual, hue_cf_cond, img_test)
    hue_eff2 = np.mean(np.abs(oracle_hue(img_cf) - hue_cf))
    print("hue effectiveness: ", hue_eff2)
    num_eff2 = _categorical_accuracy(num_labels_test, oracle_number(img_cf))
    print("number effectiveness: ", num_eff2)

    # Distance to the ground-truth counterfactual images for the given CF hues.
    gt_cond = np.concatenate((num_labels_test, hue_labels_test_cf), axis=1)
    cf_estimation = model.cf_generation(factual, gt_cond, img_test)
    l1_cf = _scaled_mae(cf_estimation, img_test_cf)
    print("CF ground truth: ", l1_cf)

    return l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1, num_eff2, hue_eff2, l1_cf
# Augment the training set (used only for oracle training) and train the hue
# oracle, an MSE-regressed model that get_metrics uses to score the hue of
# generated images.
augmented_train, num_labels_train_augmented, hue_labels_train_augmented = dataset_augmentation(img_train, num_labels_train, hue_labels_train)
oracle_hue = oracle(target="hue")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Stop after 40 epochs without val_mae improvement and keep the best weights.
# NOTE(review): validation (and hence early stopping) uses the test split.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_mae', restore_best_weights=True, patience=40)
# `metrics` must be a list; Keras 3 rejects a bare string here.
oracle_hue.compile(optimizer, loss="mse", metrics=["mae"])
batch_size = 64
with tf.device("/device:GPU:0"):
    history = oracle_hue.fit(
        x=augmented_train,
        y=hue_labels_train_augmented,
        # Documented form of validation_data is an (x_val, y_val) tuple.
        validation_data=(img_test, hue_labels_test),
        epochs=250,
        batch_size=batch_size,
        verbose=0,
        callbacks=[callback],
    )
# Train the digit oracle (categorical cross-entropy classifier) that
# get_metrics uses to recognise the digit in generated images.
oracle_number = oracle(target="number")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Stop after 40 epochs without val_accuracy improvement and keep the best weights.
# NOTE(review): validation (and hence early stopping) uses the test split.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', restore_best_weights=True, patience=40)
# `metrics` must be a list; Keras 3 rejects a bare string here.
oracle_number.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics=["accuracy"])
batch_size = 64
with tf.device("/device:GPU:0"):
    history = oracle_number.fit(
        x=augmented_train,
        y=num_labels_train_augmented,
        # Documented form of validation_data is an (x_val, y_val) tuple.
        validation_data=(img_test, num_labels_test),
        epochs=250,
        batch_size=batch_size,
        verbose=0,
        callbacks=[callback],
    )
# Train the CSAE on (conditioning labels, image) -> image reconstruction and
# evaluate it with get_metrics.
latent_dim = 16
Lambda = 0.015
batch_size = 64
n = 9  # NOTE(review): unused in the visible script — confirm before removing.
# Conditioning vectors: one-hot digit labels followed by hue labels.
cond_train = np.concatenate((num_labels_train, hue_labels_train), axis=1)
cond_test = np.concatenate((num_labels_test, hue_labels_test), axis=1)
model = CSAE(latent_dim=latent_dim, Lambda=Lambda)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer, loss=model.loss_, metrics=[model.reconstruction, model.regularization])
with tf.device("/device:GPU:0"):
    history = model.fit(
        x=[cond_train, img_train],
        y=img_train,
        # Documented (x_val, y_val) tuple form; x_val is the two-input list.
        validation_data=([cond_test, img_test], img_test),
        epochs=200,
        batch_size=batch_size,
        verbose=0,
    )
get_metrics(model)
# Recorded output of get_metrics(model) for the CSAE:
# composition MAE: 3.1012444384396076
# number reversibility: 4.507591724395752
# hue reversibility: 3.924532486125827
# digit intervention
# number effectiveness: 0.9168
# hue effectiveness: 0.0023764025
# hue intervention
# hue effectiveness: 0.0021924477
# number effectiveness: 0.9932
# CF ground truth: 3.103256905451417
# Train the CVAE on (conditioning labels, image) -> image reconstruction and
# evaluate it with get_metrics.
latent_dim = 16
batch_size = 64
# Conditioning vectors: one-hot digit labels followed by hue labels.
cond_train = np.concatenate((num_labels_train, hue_labels_train), axis=1)
cond_test = np.concatenate((num_labels_test, hue_labels_test), axis=1)
model = CVAE(latent_dim=latent_dim, kl_weight=0.6)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer, loss=model.loss_, metrics=[model.reconstruction, model.kl, model.var])
with tf.device("/device:GPU:0"):
    history = model.fit(
        x=[cond_train, img_train],
        y=img_train,
        # Documented (x_val, y_val) tuple form; x_val is the two-input list.
        validation_data=([cond_test, img_test], img_test),
        epochs=200,
        batch_size=batch_size,
        verbose=0,
    )
get_metrics(model)
# Recorded output of get_metrics(model) for the CVAE:
# composition MAE: 4.057230865582824
# number reversibility: 5.2525602746754885
# hue reversibility: 4.750587260350585
# digit intervention
# number effectiveness: 0.9867
# hue effectiveness: 0.0052283737
# hue intervention
# hue effectiveness: 0.0053906273
# number effectiveness: 0.9961
# CF ground truth: 4.083548253402114