# lint as: python3
"""Main file to run AwA experiments."""
import os
import toy_helper_v2
import ipca_v2
import tensorflow.keras as keras
import tensorflow.keras.backend as K
import tensorflow as tf

import numpy as np
from sklearn.decomposition import PCA
#from fbpca import diffsnorm, pca
from sklearn.decomposition import TruncatedSVD
from sklearn.utils.extmath import randomized_svd
from absl import app


# Pin this experiment to GPUs 6 and 7. This is done at module import time,
# i.e. before main() builds or runs any model.
os.environ.update(CUDA_VISIBLE_DEVICES="6,7")
def _save_concept_neighbors(f_val, topic_vec, images, n_concept,
                            n_neighbor=15, n_images=4000):
  """Saves the validation patches most similar to each topic vector.

  Feature patches and topic vectors are L2-normalized (with a 1e-9 guard), so
  their dot product is a cosine similarity. For every concept, the
  `n_neighbor` highest-scoring patches are located and passed to
  `toy_helper_v2.copy_save_image`.

  Args:
    f_val: validation features. The normalization over axis 3 and the
      4-axis indexing below assume (n_val, height, width, channels) --
      TODO confirm against feature_model's output.
    topic_vec: topic vectors, one per column (matmul'ed against the channel
      axis of `f_val`, normalized over axis 0).
    images: validation images indexed in the same order as `f_val`.
    n_concept: number of concepts (columns of `topic_vec`) to visualize.
    n_neighbor: number of patches saved per concept.
    n_images: number of leading validation samples searched.
  """
  f_val_n = f_val[:n_images] / (
      np.linalg.norm(f_val[:n_images], axis=3, keepdims=True) + 1e-9)
  topic_vec_n = topic_vec / (
      np.linalg.norm(topic_vec, axis=0, keepdims=True) + 1e-9)
  topic_prob = np.matmul(f_val_n, topic_vec_n)
  # (image, row, col) grid over which each concept's scores are flattened;
  # derived from the data instead of a hard-coded patch-grid size.
  grid_shape = topic_prob.shape[:3]

  for i in range(n_concept):
    scores = topic_prob[:, :, :, i].ravel()
    k = min(n_neighbor, scores.size)  # argpartition needs k <= size
    ind = np.argpartition(scores, -k)[-k:]  # top-k indices, not sorted
    print(scores[ind])
    for jc, j in enumerate(ind):
      # Flat index -> (image index, patch row, patch column).
      j_int, a, b = (int(v) for v in np.unravel_index(j, grid_shape))
      f1 = 'results/work_toy_test/concepts/concept_full_{}_{}.png'.format(i, jc)
      f2 = 'results/work_toy_test/concepts/concept_{}_{}.png'.format(i, jc)
      toy_helper_v2.copy_save_image(images[j_int, :, :, :], f1, f2, a, b)


def main(_):
  """Runs the toy concept-discovery experiment.

  Steps:
    1. Creates the toy dataset and labels with `toy_helper_v2`; the first
       80% of the samples are used for training, the rest for validation.
    2. Obtains the feature / prediction models from
       `toy_helper_v2.load_model_stm_new_multiclass`.
    3. Extracts features for every sample (or loads the cached feature file
       when `pretrain` is True) and evaluates the prediction model on them.
    4. When `trained` is False, fits the `ipca_v2` topic model and saves its
       topic / recovery vectors. Otherwise it loads previously saved topic
       vectors.
    5. Saves the validation patches nearest to each topic vector as images.

  Args:
    _: positional argv list passed by `absl.app.run`; unused.
  """
  n_concept = 8
  n_cluster = 5
  n = 50000
  n0 = int(n * 0.8)  # number of training samples; the rest are validation
  batch_size = 128
  pretrain = False  # True: load cached features instead of rebuilding
  verbose = True
  thres = 0.1
  trained = True  # True: load saved topic vectors instead of training
  para_array = [1.0]

  os.makedirs('results/work_toy_test/concepts', exist_ok=True)

  # Creates the dataset. `concept` is only consumed by get_groupacc_max
  # (presumably ground-truth concept labels -- confirm in toy_helper_v2).
  _, _, x, concept = toy_helper_v2.create_dataset(n_sample=n)

  # Loads labels.
  _, y, _ = toy_helper_v2.load_xyconcept_multiclass(n, pretrain=True, n_set=1)
  print(x[0])
  y = tf.keras.utils.to_categorical(y)
  y_train = y[:n0, :]
  y_val = y[n0:, :]

  # Loads model.
  if not pretrain:
    preprocess = tf.keras.applications.inception_v3.preprocess_input
    x_train = preprocess(x[:n0, :].astype('float32'))
    x_val = preprocess(x[n0:, :].astype('float32'))
    width, height = np.shape(x_train)[1:-1]
    feature_model, predict_model = toy_helper_v2.load_model_stm_new_multiclass(
        x_train, y_train, x_val, y_val, width, height, pretrain=pretrain)
  else:
    # The data arguments are placeholders here (`_` is whatever the last
    # unpacking bound); presumably ignored when pretrain=True -- TODO confirm.
    feature_model, predict_model = toy_helper_v2.load_model_stm_new_multiclass(
        _, _, _, _, pretrain=pretrain)

  # summary() prints itself and returns None, so it is not wrapped in print().
  print('\n\nfeature model')
  feature_model.summary()
  print('\n\npredict model')
  predict_model.summary()

  # Gets features.
  feature_path = 'results/work_toy_test/all_feature_best.npy'
  if not pretrain:
    # Use the same preprocessed inputs the model was built on (the raw `x`
    # skips inception_v3 preprocessing). Predicting train and val separately
    # avoids materializing another full-size copy of the images.
    all_feature = np.concatenate(
        [feature_model.predict(x_train), feature_model.predict(x_val)], axis=0)
    np.save(feature_path, all_feature)
  else:
    all_feature = np.load(feature_path)
  f_train = all_feature[:n0, :]
  f_val = all_feature[n0:, :]
  print(f_train.shape)

  loss_val, acc_val = predict_model.evaluate(f_val, y_val)
  print('after splitting the original model...')
  print('Loss of the trained original model: '+str(loss_val))
  print('Accuracy of the trained original model: '+str(acc_val))

  # Single-element range; widen it to sweep over several concept counts.
  for n_concept in range(n_concept, n_concept + 1):
    # Save and load paths are built from the same concept count so a later
    # `trained = True` run reads exactly what the training run wrote.
    weights_path = (
        'results/work_toy_test/latest_topic_toy_{}concepts.h5'.format(n_concept))
    topic_vec_path = 'data/toy_data/topic_vec_toy_{}concepts.npy'.format(n_concept)
    recov_vec_path = 'data/toy_data/recov_vec_toy_{}concepts.npy'.format(n_concept)

    if not trained:
      for para in para_array:
        # topic_model_new_toy also returns a (possibly updated) n_concept,
        # which the evaluation and visualization below use.
        topic_model_pr, _, _, _, n_concept, _ = ipca_v2.topic_model_new_toy(
            predict_model,
            f_train,
            y_train,
            f_val,
            y_val,
            n_concept,
            verbose=verbose,
            thres=thres,
            load=False,
            para=para)

        topic_model_pr.fit(
            f_train,
            y_train,
            batch_size=batch_size,
            epochs=30,
            validation_data=(f_val, y_val),
            verbose=verbose)
        topic_model_pr.save_weights(weights_path)
        topic_model_pr.evaluate(f_val, y_val)

        topic_vec = topic_model_pr.layers[1].get_weights()[0]
        recov_vec = topic_model_pr.layers[-3].get_weights()[0]
        # Column-wise L2 normalization (one column per concept).
        topic_vec_n = topic_vec / (
            np.linalg.norm(topic_vec, axis=0, keepdims=True) + 1e-9)

        acc = toy_helper_v2.get_groupacc_max(
            topic_vec_n,
            f_train,
            f_val,
            concept,
            n_concept,
            n_cluster,
            n0,
            verbose=verbose)
        print('accuracy:' + str(acc))

      # Saved once, after training, from the last fitted topic model.
      np.save(topic_vec_path, topic_vec)
      np.save(recov_vec_path, recov_vec)
    else:
      topic_vec = np.load(topic_vec_path)

    # Visualizes the nearest neighbors of every concept. The images are
    # indexed in the same order as f_val.
    x_vis = np.load('data/toy_data/x_data_val.npy')
    print(np.shape(x_vis))
    _save_concept_neighbors(f_val, topic_vec, x_vis, n_concept)
    
if __name__ == '__main__':
  # absl.app.run parses command-line flags, then calls main(argv).
  app.run(main)
