import numpy as np
from tensorflow.keras.utils import to_categorical


def augmented_image(img, translation, *, max_shift=3):
    """Return *img* translated by *translation*, zero-filling the exposed border.

    The image is zero-padded by ``max_shift`` pixels on both sides of its two
    leading (spatial) axes. A window of the original height/width is then
    cropped back out at an offset given by *translation*. Any trailing axes
    (e.g. channels) are left untouched.

    Args:
        img: Array whose first two axes are height and width.
        translation: ``(vertical, horizontal)`` integer offsets, each in
            ``[-max_shift, max_shift]``. A positive vertical offset moves the
            crop window down, so the content moves up. A positive horizontal
            offset moves the content left.
        max_shift: Padding width and maximum allowed absolute offset.

    Returns:
        A new ``float32`` array with the same shape as *img*.

    Raises:
        ValueError: If either offset lies outside ``[-max_shift, max_shift]``.
            Without this check, out-of-range offsets would silently produce
            a truncated or empty crop.
    """
    img = np.asarray(img)
    t_v, t_h = translation[0], translation[1]
    if not (-max_shift <= t_v <= max_shift and -max_shift <= t_h <= max_shift):
        raise ValueError(
            f"translation {tuple(translation)} outside [-{max_shift}, {max_shift}]"
        )
    height, width = img.shape[:2]
    # Pad only the two spatial axes; any trailing axes get zero padding.
    pad_width = [(max_shift, max_shift), (max_shift, max_shift)]
    pad_width += [(0, 0)] * (img.ndim - 2)
    padded_img = np.pad(img, pad_width, mode="constant")
    top, left = t_v + max_shift, t_h + max_shift
    new_img = padded_img[top:top + height, left:left + width]
    return new_img.astype("float32")

def dataset_augmentation(img_train, num_labels_train, hue_labels_train, *, num_augmentations=3):
    """Expand a training set with randomly translated copies of each image.

    Each image is kept as-is and followed by ``num_augmentations`` shifted
    copies produced by ``augmented_image``. Every copy uses a distinct,
    non-zero ``(vertical, horizontal)`` offset drawn uniformly from
    ``[-3, 3] x [-3, 3]`` with the global ``np.random`` generator. The label
    arrays are repeated row-wise so they stay aligned with the expanded images.

    Args:
        img_train: Sequence/array of images, iterated along its first axis.
        num_labels_train: Per-image labels, repeated along axis 0.
        hue_labels_train: Per-image labels, repeated along axis 0.
        num_augmentations: Shifted copies per image (default 3).

    Returns:
        ``(augmented_images, num_labels_augmented, hue_labels_augmented)``.

    Raises:
        ValueError: If ``num_augmentations`` is negative or exceeds the
            number of distinct non-zero shifts, which would make the
            rejection sampler below loop forever.
    """
    shift = 3
    max_distinct = (2 * shift + 1) ** 2 - 1  # all offsets except (0, 0)
    if not 0 <= num_augmentations <= max_distinct:
        raise ValueError(
            f"num_augmentations must be in [0, {max_distinct}], got {num_augmentations}"
        )

    augmented_train = []
    for img in img_train:
        augmented_train.append(img)
        used_translations = []  # keeps the offsets for this image distinct
        for _ in range(num_augmentations):
            # Rejection-sample a fresh non-zero offset. Two randint calls per
            # attempt, vertical first, keep seeded runs reproducible.
            while True:
                x, y = np.random.randint(-shift, shift + 1), np.random.randint(-shift, shift + 1)
                if (x, y) != (0, 0) and [x, y] not in used_translations:
                    break
            used_translations.append([x, y])
            augmented_train.append(augmented_image(img, [x, y]))

    augmented_train = np.array(augmented_train)
    copies_per_image = num_augmentations + 1
    hue_labels_train_augmented = np.repeat(hue_labels_train, copies_per_image, axis=0)
    num_labels_train_augmented = np.repeat(num_labels_train, copies_per_image, axis=0)

    return augmented_train, num_labels_train_augmented, hue_labels_train_augmented


def _mae_255(estimate, target):
    """Mean absolute error between two arrays, scaled by 255."""
    # NOTE(review): the 255 factor presumably rescales [0, 1] images to pixel units — confirm.
    return 255 * np.mean(np.abs(estimate - target))


def _categorical_accuracy(y_true, y_pred):
    """Fraction of rows whose argmax over the last axis matches between labels and predictions."""
    # Same quantity as tf.keras.metrics.CategoricalAccuracy on one unweighted
    # batch, computed with NumPy so no `tf` import is needed.
    true_cls = np.argmax(np.asarray(y_true), axis=-1)
    pred_cls = np.argmax(np.asarray(y_pred), axis=-1)
    return float(np.mean(true_cls == pred_cls))


def get_metrics(model):
    """Evaluate a counterfactual image model on the test split and print each metric.

    Reads module-level names not defined in this block: ``num_labels_test``,
    ``hue_labels_test``, ``img_test``, ``y_test``, ``hue_labels_test_cf``,
    ``img_test_cf``, ``oracle_number`` and ``oracle_hue``. The global NumPy
    RNG is re-seeded with 2 before each counterfactual draw, so results are
    reproducible.

    Args:
        model: Object exposing ``training``, ``composition``,
            ``reversibility`` and ``cf_generation``.

    Returns:
        ``(l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1, num_eff2,
        hue_eff2, l1_cf)``.
    """
    model.training = False
    factual_labels = np.concatenate((num_labels_test, hue_labels_test), axis=1)

    # composition n=1
    comp_img = model.composition(factual_labels, img_test, 1)
    l1_comp = _mae_255(comp_img, img_test)
    print("composition MAE: ", l1_comp)

    # number reversibility: pick, per sample, a random digit different from the true one
    np.random.seed(2)
    list_numbers = []
    for true_digit in y_test:
        while True:
            candidate = np.random.randint(10)
            if candidate != true_digit:
                list_numbers.append(candidate)
                break
    label_num_cf = to_categorical(list_numbers, num_classes=10)
    num_cf_labels = np.concatenate((label_num_cf, hue_labels_test), axis=1)
    reversed_img = model.reversibility(factual_labels, num_cf_labels, img_test, 1)
    l1_num_rev = _mae_255(reversed_img, img_test)
    print("number reversibility: ", l1_num_rev)

    # hue reversibility
    np.random.seed(2)
    hue_cf = np.random.random(len(img_test)).reshape(-1, 1)
    hue_cf_labels = np.concatenate((num_labels_test, hue_cf), axis=1)
    reversed_img = model.reversibility(factual_labels, hue_cf_labels, img_test, 1)
    l1_hue_rev = _mae_255(reversed_img, img_test)
    print("hue reversibility: ", l1_hue_rev)

    # digit intervention
    print("digit intervention")
    img_cf = model.cf_generation(factual_labels, num_cf_labels, img_test)
    num_eff1 = _categorical_accuracy(label_num_cf, oracle_number(img_cf))
    print("number effectiveness: ", num_eff1)
    pred = oracle_hue(img_cf)
    # NOTE(review): assumes pred has the same shape as hue_labels_test — a (N,) vs (N, 1)
    # mismatch would broadcast to (N, N) silently; confirm oracle_hue's output shape.
    hue_eff1 = np.mean(np.abs(pred - hue_labels_test))
    print("hue effectiveness: ", hue_eff1)

    # hue intervention
    print("hue intervention")
    # Re-seed and redraw so the global RNG state seen by the model call below
    # stays exactly as before; hue_cf equals the draw used above.
    np.random.seed(2)
    hue_cf = np.random.random(len(img_test)).reshape(-1, 1)
    hue_cf_labels = np.concatenate((num_labels_test, hue_cf), axis=1)
    img_cf = model.cf_generation(factual_labels, hue_cf_labels, img_test)
    pred = oracle_hue(img_cf)
    hue_eff2 = np.mean(np.abs(pred - hue_cf))
    print("hue effectiveness: ", hue_eff2)
    num_eff2 = _categorical_accuracy(num_labels_test, oracle_number(img_cf))
    print("number effectiveness: ", num_eff2)

    # CF ground-truth distance for the hue intervention
    gt_hue_labels = np.concatenate((num_labels_test, hue_labels_test_cf), axis=1)
    cf_estimation = model.cf_generation(factual_labels, gt_hue_labels, img_test)
    l1_cf = _mae_255(cf_estimation, img_test_cf)
    print("CF ground truth: ", l1_cf)

    return l1_comp, l1_num_rev, l1_hue_rev, num_eff1, hue_eff1, num_eff2, hue_eff2, l1_cf