
#!/usr/bin/env python
# coding: utf-8

#%% -------------------------------------------------------------------------
print('\n\n########## import ##########')

print('Confirm that you have [done] directory under the current working directory')

import os

import cv2
import numpy as np
import tensorflow as tf

# The script reads its helpers and label CSVs from a [done] directory under
# the current working directory. Verify that directory *before* importing
# from it; otherwise a missing directory surfaces as a bare
# ModuleNotFoundError and the explanatory message below is never shown.
cwdir = os.getcwd()
print( 'cwd =', cwdir )
# os.path.join inserts the path separator. The previous `cwdir + './done'`
# produced e.g. '/home/me./done', so the check failed even when [done] existed.
if os.path.isdir(os.path.join(cwdir, 'done')):
    print('OK, you have [done] dir, let\'s go!')
else:
    raise FileNotFoundError('You do not have [done] directory under the cwd. Please place [done] directory under the cwd.' )

# NOTE(review): `import done` resolves through sys.path, which equals the cwd
# only in interactive sessions or when the script is run from its own folder.
import done.functions as done  # noqa: E402 -- deliberately after the check above



#%% -------------------------------------------------------------------------
print('\n\n########## CIFAR 100 ##########')

# Load the CIFAR-100 train and test splits through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = (
    tf.keras.datasets.cifar100.load_data()
)

# Class-name tables read as unicode strings.
# - label_cifar: used to look up the classes to add by name.
# - label_imnet: the ImageNet names used to label the backbone's predictions.
# The 'utf-8-sig' codec strips a leading BOM if the CSV files carry one.
label_cifar = np.loadtxt(
    "./done/label_cifar100.csv",
    dtype="unicode",
    encoding='utf-8-sig',
)
label_imnet = np.loadtxt(
    "./done/label_imnet1000.csv",
    dtype="unicode",
    encoding='utf-8-sig',
)

print('CIFAR-100 imported')




#%% -------------------------------------------------------------------------
print('\n\n########## set training images for class addition ##########')

# Let's say, you have 3, 1, 2 images of baby, caterpillar, sunflower.
print('set CIFAR-training images')

# Classes to add (CIFAR-100 class names) and, in parallel, how many training
# images ("shots") to take for each one.
cif_add = ['baby', 'caterpillar', 'sunflower']
k_shot = [3, 1, 2]

print('new class =', cif_add)
print('images/class =', k_shot)

nadd = len(cif_add)
if len(k_shot) != nadd:
    raise ValueError(
        f'k_shot has {len(k_shot)} entries but cif_add has {nadd}; they must match'
    )

# CIFAR-100 integer label of each class to add. This assumes label_cifar is a
# 1-D table whose row index is the CIFAR label (TODO confirm against the CSV).
# The label is stored as a plain int: the former np.where() tuple only matched
# y_train through accidental broadcasting, and an unknown name later surfaced
# as an opaque IndexError.
cif_id_add = []
for add_class in cif_add:
    matches = np.flatnonzero(label_cifar == add_class)
    if matches.size == 0:
        raise ValueError(f'{add_class!r} is not a CIFAR-100 class name')
    cif_id_add.append(int(matches[0]))

# Collect the first k_shot[i] training images of each new class, class by
# class. Flattening the labels makes the comparison work whether y_train has
# shape (N,) or (N, 1).
y_train_flat = np.ravel(y_train)
add_images = []
for add_class, yid, k in zip(cif_add, cif_id_add, k_shot):
    # Indices of all training images of this class, found once per class
    # rather than once per shot.
    xids = np.flatnonzero(y_train_flat == yid)
    # Reject negative counts (a negative slice would take images from the end)
    # and requests for more images than the class has.
    if not 0 <= k <= xids.size:
        raise ValueError(
            f'k_shot={k} for {add_class!r}, but only {xids.size} training images exist'
        )
    add_images.extend(x_train[xids[:k]])

add_images = np.array(add_images)
# New-class index (0..nadd-1) of every collected image, e.g. [0, 0, 0, 1, 2, 2].
add_id = np.repeat(np.arange(nadd), k_shot)

# Label table for the class-added model: ImageNet names followed by the new names.
label_imnet_add = np.append( label_imnet, np.array(cif_add) )



#%% -------------------------------------------------------------------------
print('\n\n########## backbone model ##########')

# Backbone: Keras EfficientNet-B0, constructed with the library's default arguments.
model = tf.keras.applications.efficientnet.EfficientNetB0()

# Spatial input size of the backbone (axis 1 of input_shape). Images are
# later resized to insize x insize, so the script treats the input as square.
insize = model.input_shape[1]

print('backbone model imported; model output shape =',
      model.output_shape)




#%% -------------------------------------------------------------------------
print('\n\n########## Resize & preprocess training images ##########')

# Resize the few-shot images to the backbone's input size with the project
# helper. The result is presumably shaped (n, insize, insize, 3); TODO confirm.
add_images_resize = done.resize_images(
    add_images,
    insize,
)
# Apply EfficientNet's standard input preprocessing to the resized batch.
add_images_pp = tf.keras.applications.efficientnet.preprocess_input(
    add_images_resize
)

print('images preprocessed for the backbone model')
        


#%% -------------------------------------------------------------------------
print('\n\n########## Class addition by DONE ##########')
# Total number of few-shot images. The script reports this as the number of
# training inferences.
print('Training inference num =', sum(k_shot))

# Build the class-added model from the backbone, the preprocessed few-shot
# images and their new-class indices. The method itself lives in done.add_class.
model_add = done.add_class(model, add_images_pp, add_id)
print("It's DONE, new classed added; model_add output shape =", model_add.output_shape)



#%% -------------------------------------------------------------------------
print('\n\n########## Test the class-added model; a new class baby image ##########')

# Check 1: a CIFAR-100 test image of a baby, which is not one of the
# original ImageNet classes.
print('input a baby image in CIFAR-test (a new class)')

# Resize to the backbone input size, preprocess, then add a batch axis of size 1.
x = tf.keras.applications.efficientnet.preprocess_input(
    cv2.resize(x_test[573], (insize, insize))
)
x_baby = np.expand_dims(x, 0)

# Original backbone: its top prediction can only be an ImageNet label.
y = model.predict(x_baby)[0]
print(f'model_ori; TopID={y.argmax()}, {label_imnet_add[y.argmax()]}, y-val={y.max():.5f}' )

# Class-added model: its outputs also cover the added classes.
y = model_add.predict(x_baby)[0]
print(f'model_add; TopID={y.argmax()}, {label_imnet_add[y.argmax()]}, y-val={y.max():.5f}' )



#%% -------------------------------------------------------------------------
print('\n\n########## Test the class-added model; an original class lion image ##########')

# Check 2: a CIFAR-100 test image of a lion, a class ImageNet already has.
print('input a lion image in CIFAR-test (already in the imagenet classes)')

# Resize to the backbone input size, preprocess, then add a batch axis of size 1.
x = tf.keras.applications.efficientnet.preprocess_input(
    cv2.resize(x_test[15], (insize, insize))
)
x_lion = np.expand_dims(x, 0)

# Original backbone prediction.
y = model.predict(x_lion)[0]
print(f'model_ori; TopID={y.argmax()}, {label_imnet_add[y.argmax()]}, y-val={y.max():.5f}' )

# Class-added model prediction. Rescale so the first 1000 (original ImageNet)
# outputs sum to 1; this undoes the softmax spreading over the added classes
# and makes y-val comparable with model_ori's value above.
y = model_add.predict(x_lion)[0]
y /= y[:1000].sum()
print(f'model_add; TopID={y.argmax()}, {label_imnet_add[y.argmax()]}, y-val(softmax-corrected)={y.max():.5f}' )







