In [29]:
from color_mnist_generator import color_mnist_generator
from utils import dataset_augmentation
from models import CSAE, CVAE, oracle

import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
In [30]:
# Unpack the nine arrays produced by color_mnist_generator().
# y_test holds per-sample digit labels of the test split (get_metrics compares them
# against np.random.randint(10), so presumably integer class ids — TODO confirm);
# img_test_cf / hue_labels_test_cf are the counterfactual test images and their hue labels.
(
    img_train,
    y_test,
    num_labels_train,
    hue_labels_train,
    img_test,
    num_labels_test,
    hue_labels_test,
    img_test_cf,
    hue_labels_test_cf,
) = color_mnist_generator()
In [34]:
def _random_other_digits(digits, seed):
    """Draw, for every entry of `digits`, a random digit in 0..9 that differs from it.

    Uses rejection sampling on a private ``np.random.RandomState(seed)``. This
    produces exactly the draws that ``np.random.seed(seed)`` followed by global
    ``np.random.randint(10)`` calls would, without touching NumPy's global RNG.

    Parameters
    ----------
    digits : sequence
        Ground-truth digits, compared element-wise against ``randint(10)``.
    seed : int
        Seed for the private RNG.

    Returns
    -------
    list
        One counterfactual digit per entry of `digits`.
    """
    rng = np.random.RandomState(seed)
    others = []
    for digit in digits:
        candidate = rng.randint(10)
        while candidate == digit:  # reject draws equal to the true digit
            candidate = rng.randint(10)
        others.append(candidate)
    return others


def _categorical_accuracy(one_hot_labels, pred):
    """Return Keras categorical accuracy (argmax agreement) of `pred` against `one_hot_labels`."""
    metric = tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy')
    metric.update_state(one_hot_labels, pred)
    return metric.result().numpy()


def get_metrics(model, seed=2):
    """Print and return the evaluation metrics of a conditional model on the test split.

    Reads the notebook globals ``img_test``, ``num_labels_test``, ``hue_labels_test``,
    ``img_test_cf``, ``hue_labels_test_cf`` and ``y_test`` (from color_mnist_generator)
    and the trained oracles ``oracle_number`` / ``oracle_hue``, so those cells must
    have run first.

    Image distances are ``255 * mean(|a - b|)`` (presumably images lie in [0, 1], making
    the value an 8-bit pixel MAE — TODO confirm).

    Parameters
    ----------
    model : object
        A model exposing ``composition``, ``reversibility``, ``cf_generation`` and a
        ``training`` attribute (CSAE / CVAE in this notebook).
    seed : int, default 2
        Seed for the counterfactual digits and hues. The default reproduces the
        numbers previously reported by this notebook.

    Returns
    -------
    tuple
        ``(l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1, num_eff2, hue_eff2, l1_cf)``.
        End the calling cell with ``;`` to avoid echoing it.

    Side effects
    ------------
    Sets ``model.training = False`` and does not restore it; set it back yourself
    before training the model further.
    """
    # Presumably switches the model's methods to inference behaviour — see CSAE/CVAE.
    model.training = False

    factual_cond = np.concatenate((num_labels_test, hue_labels_test), axis=1)

    # Counterfactual targets, drawn from private RNGs so NumPy's global state is left alone.
    label_num_cf = to_categorical(_random_other_digits(y_test, seed), num_classes=10)
    hue_cf = np.random.RandomState(seed).random_sample(len(img_test)).reshape(-1, 1)
    digit_cf_cond = np.concatenate((label_num_cf, hue_labels_test), axis=1)
    hue_cf_cond = np.concatenate((num_labels_test, hue_cf), axis=1)

    # composition (n=1)
    comp_img = model.composition(factual_cond, img_test, 1)
    l1_comp = 255 * np.mean(np.abs(comp_img - img_test))
    print("composition MAE: ", l1_comp)

    # number reversibility: intervene on the digit, then map back to the factual condition
    reversed_img = model.reversibility(factual_cond, digit_cf_cond, img_test, 1)
    l1_num_rev = 255 * np.mean(np.abs(reversed_img - img_test))
    print("number reversibility: ", l1_num_rev)

    # hue reversibility
    reversed_img = model.reversibility(factual_cond, hue_cf_cond, img_test, 1)
    l1_hue_rev = 255 * np.mean(np.abs(reversed_img - img_test))
    print("hue reversibility: ", l1_hue_rev)

    # digit intervention: the digit should change, the hue should not
    print("digit intervention")
    img_cf = model.cf_generation(factual_cond, digit_cf_cond, img_test)
    num_eff1 = _categorical_accuracy(label_num_cf, oracle_number(img_cf))
    print("number effectiveness: ", num_eff1)
    # NOTE(review): plain |pred - target|; if hue is circular this ignores wrap-around — confirm.
    hue_eff1 = np.mean(np.abs(oracle_hue(img_cf) - hue_labels_test))
    print("hue effectiveness: ", hue_eff1)

    # hue intervention: the hue should change, the digit should not
    print("hue intervention")
    img_cf = model.cf_generation(factual_cond, hue_cf_cond, img_test)
    hue_eff2 = np.mean(np.abs(oracle_hue(img_cf) - hue_cf))
    print("hue effectiveness: ", hue_eff2)
    num_eff2 = _categorical_accuracy(num_labels_test, oracle_number(img_cf))
    print("number effectiveness: ", num_eff2)

    # distance to the ground-truth counterfactual images (hue intervention)
    gt_cf_cond = np.concatenate((num_labels_test, hue_labels_test_cf), axis=1)
    cf_estimation = model.cf_generation(factual_cond, gt_cf_cond, img_test)
    l1_cf = 255 * np.mean(np.abs(cf_estimation - img_test_cf))
    print("CF ground truth: ", l1_cf)

    return l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1, num_eff2, hue_eff2, l1_cf

Train Oracles¶

In [4]:
# Augment the training split; the augmented arrays are used below to train both oracles.
(
    augmented_train,
    num_labels_train_augmented,
    hue_labels_train_augmented,
) = dataset_augmentation(img_train, num_labels_train, hue_labels_train)
In [16]:
# Hue oracle: regresses the hue label from an image (MSE loss, MAE tracked); used in
# get_metrics to score hue effectiveness of generated images.
# NOTE(review): early stopping monitors the *test* split, so the oracle is selected on the
# same images it later helps evaluate — consider a held-out validation split instead.
oracle_hue = oracle(target="hue")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_mae',
    patience=40,
    restore_best_weights=True,
)
oracle_hue.compile(optimizer, loss="mse", metrics="mae")
batch_size = 64
with tf.device("/device:GPU:0"):
    history = oracle_hue.fit(
        x=augmented_train,
        y=hue_labels_train_augmented,
        validation_data=[img_test, hue_labels_test],
        epochs=250,
        batch_size=batch_size,
        verbose=0,
        callbacks=[callback],
    )
In [17]:
# Digit oracle: classifies the digit (categorical cross-entropy on one-hot labels); used in
# get_metrics to score number effectiveness of generated images.
# NOTE(review): as for the hue oracle, early stopping monitors the test split.
oracle_number = oracle(target="number")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=40,
    restore_best_weights=True,
)
oracle_number.compile(
    optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics="accuracy",
)
batch_size = 64
with tf.device("/device:GPU:0"):
    history = oracle_number.fit(
        x=augmented_train,
        y=num_labels_train_augmented,
        validation_data=[img_test, num_labels_test],
        epochs=250,
        batch_size=batch_size,
        verbose=0,
        callbacks=[callback],
    )

CSAE¶

In [20]:
# Train the CSAE on (condition, image) -> image, where the condition is the digit labels
# concatenated with the hue labels; the test split serves as validation data.
latent_dim = 16
Lambda = 0.015  # passed to CSAE; presumably weights the regularization metric below — confirm
batch_size = 64
n = 9  # NOTE(review): not read by any visible cell — remove or wire it into get_metrics
model = CSAE(latent_dim=latent_dim, Lambda=Lambda)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(
    optimizer,
    loss=model.loss_,
    metrics=[model.reconstruction, model.regularization],
)
with tf.device("/device:GPU:0"):
    history = model.fit(
        x=[np.concatenate((num_labels_train, hue_labels_train), axis=1), img_train],
        y=img_train,
        validation_data=[
            [np.concatenate((num_labels_test, hue_labels_test), axis=1), img_test],
            img_test,
        ],
        epochs=200,
        batch_size=batch_size,
        verbose=0,
    )
In [35]:
get_metrics(model=model);  # evaluate the CSAE; trailing ';' keeps the cell output to the printed metrics
composition MAE:  3.1012444384396076
number reversibility:  4.507591724395752
hue reversibility:  3.924532486125827
digit intervention
number effectiveness:  0.9168
hue effectiveness:  0.0023764025
hue intervention
hue effectiveness:  0.0021924477
number effectiveness:  0.9932
CF ground truth:  3.103256905451417

CVAE¶

In [36]:
# Train the CVAE on the same (condition, image) -> image task as the CSAE above.
latent_dim = 16
batch_size = 64
# kl_weight presumably scales the KL term tracked by model.kl — confirm in CVAE.
model = CVAE(latent_dim=latent_dim, kl_weight=0.6)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(
    optimizer,
    loss=model.loss_,
    metrics=[model.reconstruction, model.kl, model.var],
)
with tf.device("/device:GPU:0"):
    history = model.fit(
        x=[np.concatenate((num_labels_train, hue_labels_train), axis=1), img_train],
        y=img_train,
        validation_data=[
            [np.concatenate((num_labels_test, hue_labels_test), axis=1), img_test],
            img_test,
        ],
        epochs=200,
        batch_size=batch_size,
        verbose=0,
    )
In [37]:
get_metrics(model=model);  # evaluate the CVAE; trailing ';' keeps the cell output to the printed metrics
composition MAE:  4.057230865582824
number reversibility:  5.2525602746754885
hue reversibility:  4.750587260350585
digit intervention
number effectiveness:  0.9867
hue effectiveness:  0.0052283737
hue intervention
hue effectiveness:  0.0053906273
number effectiveness:  0.9961
CF ground truth:  4.083548253402114