import tensorflow as tf
import dataset_funcs as ds
import vrmodel_funcs as vrm
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score 
import numpy as np
from tensorflow.keras import backend as K
import pickle
import csv
import os
# Printed once when the module is imported.
print('ana')

# Locations (relative to the working directory) read from / written to by the
# functions in this module.
model_dir = './nets/'                  # saved Keras models
loss_dir = './nets/losses/'            # per-network training logs
fv_dir = './nets/firstvals/'           # '<net>Firstval.npy' files used as plot titles
fig_dir = './figs/'                    # output figures
act_dir = './activity/'                # pickled layer activity
export_dir = './saved_model'           # TF1 SavedModel loaded by record_VR_activity

def mse(A, B):
    """Return the mean of the squared element-wise difference ``A - B``."""
    residual = A - B
    return np.mean(np.square(residual))

def plot_training(net):
    """Plot training/validation loss curves for one trained network.

    The first file in ``loss_dir`` whose name contains ``net`` is read as the
    training log. ``net`` is then replaced by that filename minus its last 10
    characters, and the resolved name is used to build every output path.
    The figure title is the value stored in ``fv_dir + net + 'Firstval.npy'``.
    The curves are saved to ``fig_dir + net + '_train.png'``. For ``SupAl3``
    nets, a second figure with one panel per output head is saved to
    ``fig_dir + net + 'allloss_train.png'``.

    Args:
        net: Substring identifying the network's log file in ``loss_dir``.

    Raises:
        IndexError: if no file in ``loss_dir`` contains ``net``.
        KeyError: if the log header does not match the layout expected for
            the network family encoded in the name.
    """
    # All per-head lists are created up front. The SupAl3 checks below test
    # the *resolved* name, which can contain 'SupAl3' even when the query
    # substring did not.
    train_loss, val_loss = [], []
    train0_loss, val0_loss = [], []
    train1_loss, val1_loss = [], []
    train2_loss, val2_loss = [], []

    filename = [f for f in os.listdir(loss_dir) if net in f][0]
    net = filename[:-10]
    # The log is ';'-separated but read with the default ',' delimiter. Each
    # row is therefore one column keyed by the whole header line, and it is
    # split manually below.
    with open(loss_dir + filename, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for r in reader:
            if "UnsImg" in net:
                losses = r['epoch;binary_crossentropy;loss;val_binary_crossentropy;val_loss'].split(';')
                train_loss.append(float(losses[1]))
                val_loss.append(float(losses[3]))
            elif "UnsVim" in net:
                losses = r['epoch;kl_loss;loss;reconstruction_loss;val_kl_loss;val_reconstruction_loss;val_total_loss'].split(';')
                train_loss.append(float(losses[2]))
                # NOTE(review): index 4 is val_kl_loss, while the training
                # curve uses the total 'loss'. val_total_loss (index 6) may be
                # the intended counterpart; confirm before changing.
                val_loss.append(float(losses[4]))
            elif "UnsCpc" in net:
                losses = r['epoch;binary_accuracy;loss;val_binary_accuracy;val_loss'].split(';')
                train_loss.append(float(losses[1]))
                val_loss.append(float(losses[3]))
            elif 'SupAl3' in net:
                losses = r['epoch;loss;tf_op_layer_Mul_1_loss;tf_op_layer_Mul_2_loss;tf_op_layer_Mul_loss;val_loss;val_tf_op_layer_Mul_1_loss;val_tf_op_layer_Mul_2_loss;val_tf_op_layer_Mul_loss'].split(';')
                train_loss.append(float(losses[1]))
                val_loss.append(float(losses[5]))
                train0_loss.append(float(losses[2]))   # Mul_1
                val0_loss.append(float(losses[6]))
                train1_loss.append(float(losses[3]))   # Mul_2
                val1_loss.append(float(losses[7]))
                train2_loss.append(float(losses[4]))   # Mul
                val2_loss.append(float(losses[8]))
            else:
                losses = r['epoch;loss;val_loss'].split(';')
                train_loss.append(float(losses[1]))
                val_loss.append(float(losses[2]))

    epochs = np.arange(1, len(train_loss) + 1)
    plt.title(str(np.load(fv_dir + net + 'Firstval.npy')))
    plt.plot(epochs, train_loss)
    plt.plot(epochs, val_loss)
    plt.legend(['train', 'val'])
    plt.savefig(fig_dir + net + '_train.png')
    plt.close()

    if 'SupAl3' in net:
        head_epochs = np.arange(1, len(train1_loss) + 1)
        # One panel per head: head 1, head 2, then head 0. Head 0 is plotted
        # in the third panel instead of repeating head 2.
        heads = [
            (train1_loss, val1_loss),
            (train2_loss, val2_loss),
            (train0_loss, val0_loss),
        ]
        for panel, (head_train, head_val) in enumerate(heads, start=1):
            plt.subplot(1, 3, panel)
            plt.plot(head_epochs, head_train)
            plt.plot(head_epochs, head_val)
            if panel == 1:
                plt.legend(['train', 'val'])
        plt.savefig(fig_dir + net + 'allloss' + '_train.png')
        plt.close()

def do_perf_analysis(net, batch_size=2048, shuffle_buffer=None):
    """Evaluate a trained network on the test set and save diagnostic plots.

    The output type is chosen from substrings of ``net``. Some nets are
    delegated instead of evaluated here, and in those cases this function
    returns ``None`` once the delegate finishes:

    * Unsupervised (``'Uns'``) nets go to ``cpc_acceval`` (when ``'Cpc'`` is
      present) or to ``autoenc_viz``.
    * Multi-objective (``'Al'``) nets go to ``multiobj_perf_analysis``.

    For all other nets, the saved Keras model matching ``net`` is loaded, its
    test loss is printed and figures are written to ``fig_dir``:

    * ``task`` nets: a 4x4 confusion matrix (``<net>_hists.png``).
    * Other nets: per-dimension true-vs-predicted scatter plots
      (``<net>_scats<mean R^2>.png``) and per-dimension histograms of labels
      vs. predictions (``<net>_hists.png``).

    Args:
        net: Network name, also used to find its log file in ``loss_dir``.
        batch_size: Batch size passed to the test-set loader.
        shuffle_buffer: Forwarded to the loader. If not None, it is also used
            to shuffle the loaded test dataset.

    Raises:
        ValueError: if ``net`` matches no known output type.
    """
    # NOTE(review): the 'Tch' test opens a second, independent if-chain. A name
    # matching both chains therefore gets the later value. Kept as-is.
    output = None
    if 'Tsk' in net:
        output = 'task'
    elif 'Act' in net:
        output = 'action'
    elif 'Rwd' in net:
        output = 'reward'
    if 'Tch' in net:
        output = 'touch'
    elif 'Acc' in net:
        output = 'acceleration'
    elif 'App' in net:
        output = 'appendages'
    elif 'Jnt' in net:
        output = 'joints'
    elif 'Zax' in net:
        output = 'zaxis'
    elif 'Prp' in net:
        output = 'proprio'
    elif 'Jpr' in net:
        output = 'jprop'
    elif 'Uns' in net:
        if 'Cpc' in net:
            cpc_acceval(net, batch_size=batch_size)
        else:
            autoenc_viz(net, batch_size=batch_size)
        return None
    elif 'Al' in net:
        multiobj_perf_analysis(net, batch_size=batch_size, shuffle_buffer=shuffle_buffer)
        return None
    if output is None:
        raise ValueError('No known output type in network name: ' + net)

    filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10]
    reconstructed_model = tf.keras.models.load_model(model_dir + filename)
    test = ds.load_and_transform_for_test(output, batch_size, shuffle_buffer=shuffle_buffer)
    if shuffle_buffer is not None:
        # Dataset.shuffle returns a new dataset; the result must be kept.
        test = test.shuffle(shuffle_buffer)
    loss = reconstructed_model.evaluate(test, verbose=0)
    print('Test loss for ' + net + ' is ' + str(loss))

    # Only the first batch is plotted, so do not materialise the whole set.
    batch = next(iter(test.as_numpy_iterator()))
    inputs = batch[0]
    labels = batch[1]
    outputs = reconstructed_model(inputs, training=False).numpy()

    if output == 'task':
        outputs = np.argmax(outputs, 1)
        plt.figure()
        labels = labels.flatten()
        # confusion_mat[true, predicted] counts; assumes 4 task classes.
        confusion_mat = np.zeros((4, 4))
        for li in range(4):
            for oi in range(4):
                confusion_mat[li, oi] = np.sum((labels == li) & (outputs == oi))
        plt.imshow(confusion_mat)
        plt.colorbar()
        plt.title('Test Accuracy: ' + str(np.sum(labels == outputs) / len(labels)))
    else:
        # Per-dimension titles used for action/proprio/jprop outputs.
        body = [
            'lumbar_e', 'lumbar_b', 'lumbar_t',
            'cervical_e', 'cervical_b', 'cervical_t',
            'caudal_e', 'caudal_b',
            'hipL_s', 'hipL_a', 'hipL_e', 'kneeL', 'ankleL', 'toeL',
            'hipR_s', 'hipR_a', 'hipR_e', 'kneeR', 'ankleR', 'toeR',
            'atlas', 'mandible',
            'scapulaL_s', 'scapulaL_a', 'scapulaL_e', 'shoulderL', 'shoulderSupL',
            'elbowL', 'wristL', 'fingerL',
            'scapulaR_s', 'scapulaR_a', 'scapulaR_e', 'shoulderR', 'shoulderSupR',
            'elbowR', 'wristR', 'fingerR',
        ]
        a_total = np.shape(outputs)[1]
        # Subplot grid per output width; anything else uses a 4x10 grid.
        nrows, ncols = {1: (1, 1), 3: (1, 3), 15: (3, 5), 30: (5, 6)}.get(a_total, (4, 10))
        named_dims = output in ('action', 'proprio', 'jprop')

        r2tot = 0
        for a in range(a_total):
            plt.subplot(nrows, ncols, a + 1)
            a_labels = labels[:, a]
            a_outputs = outputs[:, a]
            label_min = np.min(a_labels)
            label_max = np.max(a_labels)
            plt.scatter(a_labels, a_outputs)
            plt.plot([label_min, label_max], [label_min, label_max], 'k')  # identity line
            if named_dims:
                plt.yticks([])
                plt.xticks([])
                plt.title(body[a], fontdict={'fontsize': 6})
                r2tot += r2_score(a_labels, a_outputs)
            else:
                plt.xlabel('True Value')
                plt.ylabel('Predicted Value')
                plt.title(r2_score(a_labels, a_outputs))
        # Mean R^2 across dimensions in the filename. It stays 0 for outputs
        # whose R^2 is only shown per panel.
        plt.savefig(fig_dir + net + '_scats' + str(r2tot / max(a_total, 1)) + '.png')
        plt.close()

        for a in range(a_total):
            plt.subplot(nrows, ncols, a + 1)
            a_labels = labels[:, a]
            a_outputs = outputs[:, a]
            label_min = np.min(a_labels)
            label_max = np.max(a_labels)
            print(label_min, label_max)
            # 25 equal-width bins spanning the label range. This also works
            # when the maximum is zero or negative. A constant dimension
            # falls back to matplotlib's automatic 25 bins.
            if label_max > label_min:
                bins = np.linspace(label_min, label_max, 26)
            else:
                bins = 25
            plt.yticks([])
            plt.hist(a_labels, color='k', bins=bins, alpha=.5)
            plt.hist(a_outputs, color='r', bins=bins, alpha=.5)
    plt.tight_layout()
    plt.savefig(fig_dir + net + '_hists.png')
    plt.close()

def multiobj_perf_analysis(net, batch_size=2048, shuffle_buffer=None):
    """Evaluate a multi-output network and save one figure per output head.

    ``'All'`` nets are treated as having task, reward, action and zaxis heads
    (loader op ``'all'``). ``'Al3'`` nets are treated as having task, reward
    and zaxis heads (loader op ``'all3'``). Model output ``i`` is compared
    with label ``i`` on the first test batch.

    The task head gets a 4x4 confusion matrix. Every other head gets
    per-dimension true-vs-predicted scatter plots titled with R^2. Each head's
    figure is saved as ``fig_dir + net + '_' + output + '_scats.png'``.

    Args:
        net: Network name, also used to find its log file in ``loss_dir``.
        batch_size: Batch size passed to the test-set loader.
        shuffle_buffer: Forwarded to the loader. If not None, it is also used
            to shuffle the loaded test dataset.

    Raises:
        ValueError: if ``net`` contains neither 'All' nor 'Al3'.
    """
    if 'All' in net:
        output_list = ['task', 'reward', 'action', 'zaxis']
        load_op = 'all'
    elif 'Al3' in net:
        output_list = ['task', 'reward', 'zaxis']
        load_op = 'all3'
    else:
        raise ValueError('No known multi-objective layout in network name: ' + net)

    filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10]
    reconstructed_model = tf.keras.models.load_model(model_dir + filename)
    test = ds.load_and_transform_for_test(load_op, batch_size, shuffle_buffer=shuffle_buffer)
    if shuffle_buffer is not None:
        # Dataset.shuffle returns a new dataset; the result must be kept.
        test = test.shuffle(shuffle_buffer)
    loss = reconstructed_model.evaluate(test, verbose=0)
    print('Test loss for ' + net + ' is ' + str(loss))

    # Only the first batch is plotted, so do not materialise the whole set.
    batch = next(iter(test.as_numpy_iterator()))
    inputs = batch[0]
    labels_all = batch[1]
    outputs_all = reconstructed_model(inputs, training=False)

    for head, output in enumerate(output_list):
        outputs = outputs_all[head].numpy()
        labels = labels_all[head]
        if output == 'task':
            outputs = np.argmax(outputs, 1)
            plt.figure()
            labels = labels.flatten()
            # confusion_mat[true, predicted] counts; assumes 4 task classes.
            confusion_mat = np.zeros((4, 4))
            for li in range(4):
                for pi in range(4):
                    confusion_mat[li, pi] = np.sum((labels == li) & (outputs == pi))
            plt.imshow(confusion_mat)
            plt.colorbar()
            plt.title('Test Accuracy: ' + str(np.sum(labels == outputs) / len(labels)))
        else:
            a_total = np.shape(outputs)[1]
            # Subplot grid per output width; anything else uses a 4x10 grid.
            nrows, ncols = {1: (1, 1), 3: (1, 3), 15: (3, 5), 30: (5, 6)}.get(a_total, (4, 10))
            for a in range(a_total):
                plt.subplot(nrows, ncols, a + 1)
                a_labels = labels[:, a]
                a_outputs = outputs[:, a]
                label_min = np.min(a_labels)
                label_max = np.max(a_labels)
                plt.scatter(a_labels, a_outputs)
                plt.plot([label_min, label_max], [label_min, label_max], 'k')  # identity line
                plt.title(r2_score(a_labels, a_outputs))
                if output != 'action':
                    plt.xlabel('True Value')
                    plt.ylabel('Predicted Value')
                else:
                    plt.yticks([])
                    plt.xticks([])
        plt.savefig(fig_dir + net + '_' + output + '_scats.png')
        plt.close()


def cpc_acceval(net, batch_size=2048, shuffle_buffer=None):
    """Evaluate a binary (CPC-style) classifier and save its confusion matrix.

    The function performs these steps:

    1. Load the saved model matching ``net`` and print its test loss.
    2. Threshold the model output on the first test batch at 0.5.
    3. Print the number of correct predictions and the batch size.
    4. Save a 2x2 confusion matrix (rows: true label, columns: prediction),
       titled with the accuracy, to ``fig_dir + net + '_hists.png'``.

    Args:
        net: Network name, also used to find its log file in ``loss_dir``.
        batch_size: Batch size passed to the CPC test-set loader.
        shuffle_buffer: Forwarded to the CPC test-set loader.

    Returns:
        None
    """
    filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10]
    reconstructed_model = tf.keras.models.load_model(model_dir + filename)
    test = ds.load_and_transform_for_CPCtest(batch_size, shuffle_buffer=shuffle_buffer)
    loss = reconstructed_model.evaluate(test, verbose=0)
    print('Test loss for ' + net + ' is ' + str(loss))

    # Only the first batch is scored, so do not materialise the whole set.
    batch = next(iter(test.as_numpy_iterator()))
    inputs = batch[0]
    # Squeeze labels as well as predictions so the comparison is
    # element-wise. If labels arrive as (n, 1) and predictions as (n,), ==
    # would otherwise broadcast to an (n, n) matrix and corrupt the accuracy.
    labels = np.atleast_1d(np.squeeze(batch[1]))
    scores = np.squeeze(reconstructed_model(inputs, training=False).numpy())
    outputs = np.atleast_1d(np.squeeze((scores > .5) * 1))

    correct = np.sum(outputs == labels)
    print(correct, len(labels))
    plt.figure()
    plt.title(correct / len(labels))
    confusion_mat = np.zeros((2, 2))
    for li in range(2):
        for oi in range(2):
            confusion_mat[li, oi] = np.sum((labels == li) & (outputs == oi))
    plt.imshow(confusion_mat)
    plt.colorbar()
    plt.savefig(fig_dir + net + '_hists.png')
    plt.close()
    return None

def autoenc_viz(net, batch_size=2048, shuffle_buffer=2048):
    """Visualise reconstructions of an autoencoder-style network.

    Loads the saved model matching ``net``, runs it on the first test batch,
    and writes two figures to ``fig_dir``:

    * ``<net>_recon.png``: up to 7 inputs on the top row and the
      corresponding model outputs below, with the batch MSE as the figure
      title.
    * ``<net>_hists.png``: per-channel value histograms of inputs vs.
      outputs.

    Args:
        net: Network name. Must contain 'Img' or 'Vim' (image output) or
            'Nxf' (next_frame output).
        batch_size: Batch size passed to the test-set loader.
        shuffle_buffer: Forwarded to the loader. If not None, it is also used
            to shuffle the loaded test dataset.

    Raises:
        ValueError: if ``net`` names no known autoencoder output.
    """
    if 'Img' in net or 'Vim' in net:
        output = 'image'
    elif 'Nxf' in net:
        output = 'next_frame'
    else:
        raise ValueError('No known autoencoder output in network name: ' + net)
    filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10]
    if 'UnsVim' in net:
        tf.config.experimental_run_functions_eagerly(False)
    reconstructed_model = tf.keras.models.load_model(model_dir + filename)
    test = ds.load_and_transform_for_test(output, batch_size, shuffle_buffer=shuffle_buffer)
    if shuffle_buffer is not None:
        # Dataset.shuffle returns a new dataset; the result must be kept.
        test = test.shuffle(shuffle_buffer)

    # Only the first batch is visualised, so do not materialise the whole set.
    batch = next(iter(test.as_numpy_iterator()))
    inputs = np.array(batch[0])
    print(np.shape(inputs))
    # evaluate() is not called: the loss cannot be computed for the sparse
    # UnsVim model.
    outputs = reconstructed_model(inputs, training=False).numpy()

    # Top row: inputs. Bottom row: model outputs. At most 7 examples.
    n_show = min(7, inputs.shape[0])
    for i in range(n_show):
        plt.subplot(2, 7, i + 1)
        plt.imshow(inputs[i, :, :, :])
        plt.xticks([])
        plt.yticks([])
        plt.subplot(2, 7, i + 1 + 7)
        plt.imshow(outputs[i, :, :, :])
        plt.xticks([])
        plt.yticks([])
    # The MSE covers the whole batch, so it titles the figure rather than
    # only the last subplot.
    # NOTE(review): the error is measured against the inputs. For
    # 'next_frame' nets, batch[1] may be the proper target; confirm.
    plt.suptitle(str(mse(inputs.flatten(), outputs.flatten())))
    plt.savefig(fig_dir + net + '_recon.png')
    plt.close()

    plt.figure()
    cols = ['r', 'g', 'b']
    bins = np.arange(0, 1.1, .1)
    for c in range(3):
        plt.subplot(1, 3, c + 1)
        plt.hist(inputs[:, :, :, c].flatten(), color='k', bins=bins, alpha=.5)
        plt.hist(outputs[:, :, :, c].flatten(), color=cols[c], bins=bins, alpha=.5)
    plt.savefig(fig_dir + net + '_hists.png')
    plt.close()



def record_activity(net, batch_size=2048, shuffle_buffer=None):
    """Record intermediate-layer activity of a saved model on one test batch.

    Model selection:

    * 'UnsVim' / 'UnsCpc' nets: the ``'_enc'`` model saved next to the
      network is used.
    * 'UntRnd' nets: ``model_dir + net`` is loaded directly.
    * Otherwise: the model matching the log file in ``loss_dir`` is used.

    The recorded layers are:

    * every layer whose name contains 're_lu' and that immediately follows a
      layer whose name contains 'pool';
    * every layer whose name contains 'dense' and no underscore.

    The result ``[layer_outs, labels[0], labels[1], labels[2], labels[3],
    inputs]`` is pickled to ``act_dir + net + '_Act.pickle'``.

    Args:
        net: Network name.
        batch_size: Batch size passed to the test-set loader.
        shuffle_buffer: Forwarded to the test-set loader.
    """
    output = 'all_prop' if 'Prp' in net else 'all'

    if 'UnsVim' in net or 'UnsCpc' in net:
        filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10] + '_enc'
    elif 'UntRnd' in net:
        filename = net
    else:
        filename = [f for f in os.listdir(loss_dir) if net in f][0][:-10]
    reconstructed_model = tf.keras.models.load_model(model_dir + filename)
    # summary() prints the table itself and returns None.
    reconstructed_model.summary()

    test = ds.load_and_transform_for_test(output, batch_size, shuffle_buffer=shuffle_buffer)
    # Only the first batch is recorded, so do not materialise the whole set.
    first = next(iter(test.as_numpy_iterator()))
    inputs = first[0]
    # Assumed to hold at least four entries (indexed 0..3 below); confirm
    # against ds.load_and_transform_for_test.
    batch = first[1]

    layers = reconstructed_model.layers
    rec_layers = []
    for l, layer in enumerate(layers):
        # ReLU directly after a pooling layer (guarded for a trailing pool).
        if 'pool' in layer.name and l + 1 < len(layers):
            following = layers[l + 1]
            if 're_lu' in following.name:
                rec_layers.append(following.output)
                print(following.name)
        # Layers named 'dense' with no underscore suffix.
        if 'dense' in layer.name and '_' not in layer.name:
            rec_layers.append(layer.output)
            print(layer.name)

    functor = K.function([reconstructed_model.input], [rec_layers])  # evaluation function
    layer_outs = functor([inputs])
    with open(act_dir + net + '_Act.pickle', 'wb') as afile:
        pickle.dump([layer_outs, batch[0], batch[1], batch[2], batch[3], inputs], afile)

def record_VR_activity(batch_size=2048, shuffle_buffer=None):
    """Record image-encoder activity of the exported TF1 agent on one test batch.

    The first batch of the ``'all'`` test set is fed, frame by frame, into the
    tensor ``agent_0/.../image_encoder/truediv:0`` of the SavedModel in
    ``export_dir``. Four encoder ReLU tensors are fetched for each frame.
    The per-layer activity is pickled to
    ``act_dir + 'RllVir_mn0000_Act.pickle'`` as
    ``[[[activity]], labels[0], labels[1], labels[2], labels[3], inputs]``.

    Note:
        Calls ``tf.compat.v1.disable_v2_behavior()``, which switches
        TensorFlow's global behaviour for the remainder of the process.

    Args:
        batch_size: Batch size passed to the test-set loader.
        shuffle_buffer: Forwarded to the test-set loader.
    """
    test = ds.load_and_transform_for_test('all', batch_size, shuffle_buffer=shuffle_buffer)
    # Only the first batch is recorded, so do not materialise the whole set.
    first = next(iter(test.as_numpy_iterator()))
    inputs = first[0]
    print(inputs)
    batch = first[1]

    import tensorflow.compat.v1 as tf1
    tf1.disable_v2_behavior()

    layers = [
        'agent_0/step/agent_0/encode/image_encoder/resnet/residual_0_0/Relu:0',
        'agent_0/step/agent_0/encode/image_encoder/resnet/residual_1_0/Relu:0',
        'agent_0/step/agent_0/encode/image_encoder/resnet/residual_2_0/Relu:0',
        'agent_0/step/agent_0/encode/image_encoder/Relu:0',
    ]
    with tf1.Session() as sess:
        tf1.saved_model.loader.load(sess, ["tag"], export_dir)
        norm_im_in = sess.graph.get_tensor_by_name('agent_0/step/agent_0/encode/image_encoder/truediv:0')
        res_outs = [sess.graph.get_tensor_by_name(name) for name in layers]
        per_layer = [[] for _ in layers]
        for f in range(inputs.shape[0]):
            frame = np.expand_dims(inputs[f, :, :, :], 0)
            # Fetch all layers in a single run per frame instead of one run
            # per layer.
            acts = sess.run(res_outs, {norm_im_in: frame})
            for store, act in zip(per_layer, acts):
                store.append(np.array(act))
        activity = [np.squeeze(np.array(l_act)) for l_act in per_layer]
    with open(act_dir + 'RllVir_mn0000_Act.pickle', 'wb') as afile:
        pickle.dump([[[activity]], batch[0], batch[1], batch[2], batch[3], inputs], afile)
    

    



