

import numpy as np
import tensorflow as tf
import cv2


def resize_images(images, size_after):
    """Resize every image in a batch to a square of side ``size_after``.

    Parameters
    ----------
    images : np.ndarray
        Batch indexed as ``images[i, :, :, :]`` -- assumed (N, H, W, C).
        Any channel count accepted by ``cv2.resize`` works (previously only
        C == 3 did).
    size_after : int
        Side length, in pixels, of each output image.

    Returns
    -------
    np.ndarray
        Shape (N, size_after, size_after, C), same dtype as ``images``.
    """
    n_images = images.shape[0]
    n_channels = images.shape[3]
    # Allocate directly in the target dtype instead of float64 + astype copy.
    images_resized = np.zeros(
        (n_images, size_after, size_after, n_channels), dtype=images.dtype
    )
    for i in range(n_images):
        # cv2.resize takes dsize as (width, height); the output is square here.
        resized = cv2.resize(images[i], (size_after, size_after))
        # cv2.resize returns a 2-D array for single-channel input, so restore
        # the channel axis before storing.
        images_resized[i] = resized.reshape(size_after, size_after, n_channels)
    return images_resized



def quantile_norm(x, ref_dist):
    """Quantile-normalise each column of ``x`` onto the values of ``ref_dist``.

    Within every column, the k-th smallest entry (k = 0..N-1) is replaced by
    the ``(k + 0.5) / N`` quantile of the flattened ``ref_dist``, where N is
    the number of rows. Each column therefore keeps its rank order but takes
    on the reference distribution's values.

    Parameters
    ----------
    x : array_like, shape (N,) or (N, M)
        Values to normalise. A 1-D input is treated as a single column.
    ref_dist : array_like
        Reference sample of any shape; it is flattened before use.

    Returns
    -------
    np.ndarray
        Shape (N, M); a 1-D ``x`` yields shape (N, 1). Floating-point ``x``
        keeps its dtype; any other dtype is promoted to float64 so the
        quantile values are not truncated.

    Raises
    ------
    ValueError
        If ``x`` is not 1-D or 2-D.
    """
    x = np.asarray(x)
    if x.ndim not in (1, 2):
        raise ValueError(f"x must be 1-D or 2-D, got shape {x.shape}")
    if x.ndim == 1:
        x = x.reshape(x.shape[0], 1)

    n_rows = x.shape[0]
    # Mid-point probabilities (k + 0.5) / N built from an integer range, so
    # there are always exactly N of them (np.arange with the float step 1/N
    # can return N + 1 values due to rounding). Also safe for N == 0.
    probs = (np.arange(n_rows) + 0.5) / n_rows if n_rows else np.empty(0)
    ref_values = np.quantile(np.ravel(ref_dist), probs)

    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    out = np.empty(x.shape, dtype=out_dtype)
    # order[k, j] is the row index of the k-th smallest value in column j;
    # each column's argsort is a permutation, so every entry gets written.
    order = np.argsort(x, axis=0)
    np.put_along_axis(out, order, ref_values[:, np.newaxis], axis=0)
    return out


def get_x(model, add_images_pp, add_labels):
    """Compute one mean penultimate-layer feature vector per new class label.

    Parameters
    ----------
    model : tf.keras.Model
        Trained classifier. Its last layer's first weight array supplies the
        feature width (``get_weights()[0].shape[0]``), and ``layers[-2]``
        supplies the features.
    add_images_pp : array_like
        Images for the new classes, indexable along axis 0. Presumably already
        preprocessed for ``model`` (``_pp`` suffix) -- confirm with callers.
    add_labels : array_like
        One label per image in ``add_images_pp``.

    Returns
    -------
    np.ndarray
        Shape (Nhide, n_new_classes), float64. Column j is the mean feature
        vector over the images whose label equals the j-th sorted unique label.

    Raises
    ------
    ValueError
        If ``add_labels`` is empty.
    """
    add_labels = np.ravel(np.asarray(add_labels))
    add_images_pp = np.asarray(add_images_pp)
    # Rows of the top layer's kernel = width of the penultimate feature vector.
    n_hidden = model.layers[-1].get_weights()[0].shape[0]
    hidden_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)

    class_ids = np.unique(add_labels)  # np.unique already returns sorted values
    if class_ids.size == 0:
        raise ValueError("add_labels is empty: there are no classes to add")

    columns = []
    for class_id in class_ids:
        idx = np.flatnonzero(add_labels == class_id)
        # One batched predict per class instead of one predict call per image.
        feats = hidden_model.predict(add_images_pp[idx])
        feats = np.reshape(feats, (len(idx), n_hidden))
        # Accumulate in float64, as the original float64 running sum did.
        columns.append(feats.mean(axis=0, dtype=np.float64))
    # Build all columns at once rather than re-hstacking on every class.
    return np.stack(columns, axis=1)


def change_w(model, x_add, reconstruct, activation='softmax'):
    """Build a model whose new top Dense layer includes weights for new classes.

    The new kernel columns are ``x_add`` quantile-normalised onto the values
    of the existing kernel. The new biases are all set to the median of the
    existing biases.

    Parameters
    ----------
    model : tf.keras.Model
        Trained classifier whose last layer has ``get_weights() == [kernel,
        bias]``.
    x_add : np.ndarray
        Shape (Nhide, Nadd): one column per new class (e.g. the output of
        ``get_x``). A 1-D array is treated as a single class.
    reconstruct : int
        0 keeps the existing classes and appends the new ones; 1 keeps only
        the new classes.
    activation : str or callable, optional
        Activation of the rebuilt top layer (default ``'softmax'``, as before).

    Returns
    -------
    tf.keras.Model
        New model sharing ``model``'s input and layers up to ``layers[-2]``.
        NOTE: those layer objects are reused, not copied, so training the
        returned model also changes ``model``.

    Raises
    ------
    ValueError
        If ``reconstruct`` is not 0 or 1 (previously this printed a message
        and then crashed with UnboundLocalError).
    """
    if reconstruct not in (0, 1):
        raise ValueError('[reconstruct] must be 0, 1, or empty')

    weights_orig = model.layers[-1].get_weights()
    kernel_orig, bias_orig = weights_orig[0], weights_orig[1]

    x_add = np.asarray(x_add)
    if x_add.ndim == 1:
        x_add = x_add.reshape(-1, 1)
    n_add = x_add.shape[1]

    # Match the scale of the new columns to the trained kernel's value distribution.
    kernel_add = quantile_norm(x_add, kernel_orig.flatten())
    bias_add = np.full(n_add, np.median(bias_orig), dtype=float)

    if reconstruct == 0:
        new_weights = [np.hstack([kernel_orig, kernel_add]),
                       np.hstack([bias_orig, bias_add])]
    else:
        new_weights = [kernel_add, bias_add]
    n_out = new_weights[1].shape[0]

    hidden_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)
    outputs = tf.keras.layers.Dense(n_out, activation=activation)(hidden_model.output)
    model_add = tf.keras.Model(inputs=hidden_model.input, outputs=outputs)
    model_add.layers[-1].set_weights(new_weights)

    return model_add



def add_class(model, add_images_pp, add_labels, reconstruct=0):
    """Return a new model whose top layer also covers the classes in ``add_labels``.

    Per-class prototype features are computed with ``get_x`` and turned into
    top-layer weights by ``change_w``. With ``reconstruct=0`` (default) the
    original classes are kept and the new ones appended; with
    ``reconstruct=1`` only the new classes remain.
    """
    # One prototype feature column per distinct label ...
    prototypes = get_x(model, add_images_pp, add_labels)
    # ... mapped into weights of a rebuilt classification layer.
    return change_w(model, prototypes, reconstruct)



