"""Code to evaluate OOD detection performance of our concept-based approach"""
## plot concept sensitivity statistics across concepts
import sys
import os
import time
import argparse
import numpy as np
import random
import pickle
import joblib
import itertools
import copy
import pandas as pd
import scipy.stats
import scipy.io as sio
from scipy.special import comb
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
from scipy.stats import norm
sns.set_style("whitegrid") #darkgrid

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.layers as layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.metrics as metrics

import ipca_v2
import AwA2_helper
from test_baselines import run_eval


from utils.test_utils import arg_parser, prepare_data, get_measures
from utils.test_utils import ConceptProfiles
from utils.test_utils import get_recovered_features
from utils.ood_utils import run_ood_over_batch
from utils.stat_utils import compute_pval, bayes_posterior, FLD, multivar_separa
from utils.plot_utils import plot_stats, plot_per_class_stats, plot_score_distr
from utils.plot_utils import plot_tsne
from utils import log


# Module-level Keras softmax activation layer, built once at import time.
# NOTE(review): not referenced anywhere in the visible part of this file.
softmax = layers.Activation('softmax')

def remove_duplicate_concepts(topic_vec, thresh=0.95):
    """Drop near-duplicate concept vectors.

    Two concepts are considered duplicates when the cosine similarity of
    their vectors is at least ``thresh``. Within each group of duplicates the
    concept with the smallest index is kept and the others are removed.

    :param topic_vec: concept matrix, dim=(dim_features, n_concepts)
    :param thresh: cosine-similarity threshold at or above which two concepts
        are treated as duplicates (default 0.95)
    :return: tuple ``(topic_vec, dict_similar_topic, dict_topic_mapping)``:
        the matrix with duplicate columns removed; a dict mapping every
        original concept index to the list of indices it is similar to; and a
        dict mapping every original index to its index after removal
        (``None`` for removed concepts)
    """
    n_concept = topic_vec.shape[1]
    topic_vec_n = topic_vec / (np.linalg.norm(topic_vec, axis=0, keepdims=True) + 1e-9)

    # Pairwise cosine similarities. Subtracting the identity keeps a concept
    # from being reported as a duplicate of itself (for positive thresholds).
    similarity = np.transpose(topic_vec_n) @ topic_vec_n - np.eye(n_concept)

    dict_similar_topic = {}
    idx_delete = set()
    for i in range(n_concept):
        similar = np.flatnonzero(similarity[i] >= thresh).tolist()
        dict_similar_topic[i] = similar
        # Only higher-indexed partners are deleted, so the first member of
        # each duplicate group survives.
        idx_delete.update(j for j in similar if j > i)

    removed = sorted(idx_delete)
    print(dict_similar_topic)
    print(removed)

    topic_vec = np.delete(topic_vec, removed, axis=1)

    # Map old concept indices to their position in the reduced matrix.
    dict_topic_mapping = {}
    count = 0
    for i in range(n_concept):
        if i in idx_delete:
            dict_topic_mapping[i] = None
        else:
            dict_topic_mapping[i] = count
            count += 1
    print('concept mapping between before/after duplicate removal......')
    print(dict_topic_mapping)

    return topic_vec, dict_similar_topic, dict_topic_mapping

def visualize_nn(test_loader, topic_vec, f_test, save_dir, logger):
    """Save the image patches most similar to each concept vector.

    For every concept, the 15 feature locations with the highest cosine
    similarity to the concept vector are logged; those whose similarity
    exceeds 0.80 are exported via ``AwA2_helper.copy_save_image`` into
    ``<save_dir>/concept<i>/``.

    :param test_loader: object exposing ``filepaths``, indexed by image number
    :param topic_vec: concept matrix, one column per concept
    :param f_test: 4-D feature array whose last axis is the feature dimension
        (flat indices are decoded assuming axes 1 and 2 have equal size)
    :param save_dir: existing directory under which per-concept folders are created
    :param logger: logger used to report the top similarity scores
    """
    neighbors_num = 15
    n_size = np.shape(f_test)[1]
    cells_per_image = n_size * n_size

    # Cosine similarity between every feature location and every concept.
    unit_features = f_test / (np.linalg.norm(f_test, axis=3, keepdims=True) + 1e-9)
    unit_topics = topic_vec / (np.linalg.norm(topic_vec, axis=0, keepdims=True) + 1e-9)
    topic_prob = np.matmul(unit_features, unit_topics)

    for i in range(topic_vec.shape[1]):
        savepath = os.path.join(save_dir, 'concept' + str(i))
        if not os.path.isdir(savepath):
            os.mkdir(savepath)

        flat_scores = topic_prob[:, :, :, i].flatten()
        ind = np.argpartition(flat_scores, -neighbors_num)[-neighbors_num:]
        sim_list = flat_scores[ind]
        logger.info(f'[ID TEST: CONCEPT {i}] top-{neighbors_num} scores: {sim_list}')

        for jc, (j, similarity) in enumerate(zip(ind, sim_list)):
            if not similarity > 0.80:
                continue
            # Decode the flat index into (image number, row, column).
            j_int, within_image = divmod(int(j), cells_per_image)
            a, b = divmod(within_image, n_size)
            f1 = None
            f2 = savepath + '/concept_{}_{}.png'.format(i, jc)
            x_test_filename = test_loader.filepaths[j_int]
            AwA2_helper.copy_save_image(x_test_filename, f1, f2, a, b)



def prepare_profiles(feature_model, topic_vec, num_classes, args, logger):
    """Return the concept-score profiles of the AwA2 validation split.

    The profiles are cached in ``<args.result_dir>/AwA2_val_concept_dict.pkl``.
    When the cache file exists it is loaded and returned; otherwise the
    profiles are computed with ``ConceptProfiles`` from the images under
    ``./data/Animals_with_Attributes2/val`` and written to the cache.

    :param feature_model: model passed to ``ConceptProfiles.prepare_concept_dict``
    :param topic_vec: concept matrix passed to ``ConceptProfiles.prepare_concept_dict``
    :param num_classes: number of classes passed to ``ConceptProfiles.setUp``
    :param args: namespace providing ``result_dir``
    :param logger: logger used for progress messages
    :return: the concept dictionary (``ConceptProfiles.concept_dict``)
    """
    profile_path = "{}/AwA2_val_concept_dict.pkl".format(args.result_dir)

    if os.path.exists(profile_path):
        logger.info("Loading concept profiles of AwA2 val set from {}".format(profile_path))
        # NOTE: pickle.load can execute arbitrary code; only point result_dir
        # at cache files produced by this project.
        with open(profile_path, 'rb') as f:
            return pickle.load(f)

    logger.info("Profiling the distribution of concept scores from val set...")

    tf.random.set_seed(0)
    # Rescaling only, no augmentation; shuffle is off so the sample order is fixed.
    datagen = ImageDataGenerator(rescale=1./255.)
    data_loader = datagen.flow_from_directory(
        "./data/Animals_with_Attributes2/val",
        batch_size=350, target_size=(224, 224),
        class_mode='categorical',
        shuffle=False)

    profiler = ConceptProfiles()
    profiler.setUp(num_classes, data_loader)
    profiler.prepare_concept_dict(feature_model, topic_vec)
    concept_dict = profiler.concept_dict

    logger.info("Saving concept profiles of AwA2 val set in {}".format(profile_path))
    with open(profile_path, 'wb') as f:
        pickle.dump(concept_dict, f)

    return concept_dict


def explain_topK(scores, top_k, separa, concept_idx=None, figname=None):
    """
    Plot bar graph of average concept scores for a set of concepts.

    :param scores: concept scores, dim=(N,num_concepts)
    :param top_k: number of concepts to show when ``concept_idx`` is not
        given; also determines the figure width
    :param separa: separability score averaged across concepts or per-class
        multivariate separability (shown in the title)
    :param concept_idx: indices of the concepts to plot. When omitted (None
        or empty), the ``top_k`` concepts with the largest absolute mean
        score are used.
    :param figname: path the figure is saved to (must be provided)
    """
    s_mean = np.mean(scores, axis=0)

    # "Not provided" means None or empty. Testing truthiness of the values
    # (``.any()``) would wrongly discard a selection such as [0].
    if concept_idx is None or np.size(concept_idx) == 0:
        concept_idx = np.argsort(np.abs(s_mean))[::-1][:top_k]
    concept_idx = np.asarray(concept_idx)

    bar_means = s_mean[concept_idx]
    bar_stds = np.std(scores[:, concept_idx], axis=0)

    # Size the bar positions from the concepts actually plotted so a
    # selection shorter than top_k does not produce a shape mismatch.
    num_types = 1
    num_concepts = len(bar_means)
    bar_width = 0.35
    index = np.arange(num_concepts) * bar_width * (num_types + 1)

    fig, ax = plt.subplots(figsize=(3*top_k/5, 3))
    ax.bar(index, bar_means, bar_width, yerr=bar_stds)
    ax.set_title('Separability: {:.3f}'.format(separa))
    ax.set_ylabel('Concept score')
    ax.set_ylim(-0.25, 1)
    ax.set_xticks(index + num_types * bar_width / 2)
    ax.set_xticklabels(['C{}'.format(c) for c in concept_idx])
    fig.tight_layout()
    plt.savefig(figname)
    plt.close()

def explain_relative(scores, labels, separa, figname, figname_dist, top_k=6):
    """
    Draw a grouped bar chart comparing mean concept scores across score types.

    scores: dictionary of concept scores keyed by label (e.g. groundtruth ID,
            groundtruth OOD, ID -> OOD, OOD -> ID); each value is indexed as
            [samples, concepts]
    labels: keys of ``scores`` to plot, one bar series per label
    separa: separability scores, dim=(num_concepts,) (currently not used)
    figname: path the bar chart is saved to
    figname_dist: currently not used
    top_k: the first ``top_k`` concepts (indices 0..top_k-1) are plotted
    """
    concept_idx = np.arange(top_k)
    bar_width = 0.35
    num_types = len(labels)

    # Left edge of each concept's group of bars; groups are spaced one extra
    # bar width apart so neighbouring groups do not overlap.
    group_start = np.arange(top_k) * bar_width * (num_types + 1)

    fig, ax = plt.subplots(figsize=(3*top_k/2, 3))
    for offset, label in enumerate(labels):
        selected = scores[label][:, concept_idx]
        ax.bar(
            group_start + offset * bar_width,
            np.mean(selected, axis=0),
            bar_width,
            yerr=np.std(selected, axis=0),
            label=label,
        )

    ax.set_title('Concept scores for each concept and ID/OOD data')
    ax.set_ylabel('Concept score')
    ax.set_xticks(group_start + num_types * bar_width / 2)
    tick_names = ['concept {}'.format(c) for c in concept_idx]
    ax.set_xticklabels(tick_names, rotation=45)
    ax.legend()
    fig.tight_layout()
    plt.savefig(figname)
    plt.close()

def save_images(filepaths, figname, k=5):
    """Save up to ``k`` randomly chosen images, resized to 200x200, as PNG files.

    Output files are named ``<figname>_<n>.png`` with ``n`` counting from 0.

    :param filepaths: sequence of image file paths (left unmodified)
    :param figname: output path prefix
    :param k: maximum number of images to save
    """
    k = max(0, min(k, len(filepaths)))

    # Shuffle a copy so the caller's sequence is not reordered as a side effect.
    chosen = np.random.permutation(filepaths)[:k]

    for count, path in enumerate(chosen):
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        # The context manager closes the source file once the copy is written.
        with Image.open(path) as img:
            img.resize((200, 200), Image.LANCZOS).save(figname + '_' + str(count) + '.png')

def _remap_concepts(ranked_concepts, concept_map):
    """Translate ranked original concept ids to post-deduplication ids, dropping removed concepts."""
    remapped = (concept_map[c] for c in ranked_concepts)
    # Compare against None explicitly: the concept remapped to index 0 is valid.
    return np.array([c for c in remapped if c is not None], dtype=int)


def main(args):
    """Generate per-class concept-based explanations of OOD detection results.

    Steps performed:
      1. Load the feature/prediction models and the trained topic model, and
         remove duplicate concept vectors.
      2. Load precomputed ConceptSHAP results
         (``<explain_dir>/<score>_SHAP.pkl``) and concept separabilities
         (``separability_<score>_AwA2_<out_data>_multiv.npy``).
      3. For every class with a defined separability and enough samples, plot
         the mean scores of the top-10 concepts (ranked by detection SHAP
         value) for samples detected as ID and as OOD, and save example
         images for each detection outcome.

    :param args: parsed command-line arguments; reads ``gpu``, ``score``,
        ``out_data``, ``out_data_dim``, ``result_dir`` and ``model_path``
        (and sets ``batch_size`` when the score is ODIN)
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    logger = log.setup_logger(args, filename="eval_{}.log".format(args.score))
    in_loader, out_loader = prepare_data(args, logger)
    INPUT_SHAPE = (args.out_data_dim, args.out_data_dim)
    N_CLASSES = 50
    N_CONCEPTS_ORIG = 100

    if args.score == 'ODIN':
        args.batch_size = 200

    # Output directories (creating explain_dir also creates 'explanations').
    os.makedirs(os.path.join(args.result_dir, 'plots'), exist_ok=True)
    explain_dir = os.path.join(args.result_dir, 'explanations', args.out_data+'_'+args.score)
    os.makedirs(explain_dir, exist_ok=True)

    ## load trained_model
    logger.info(f"Loading model from {args.model_path}")
    # The first two positional arguments are placeholders (0), as before.
    feature_model, predict_model = AwA2_helper.load_model_inception_new(0, 0, input_size=INPUT_SHAPE, pretrain=True, n_gpus=1, modelname=args.model_path)

    # Features of one ID batch; used to construct and call the topic model
    # before its weights are loaded.
    in_test_features = feature_model.predict(in_loader.next()[0])

    ## load topic model
    topic_model = ipca_v2.TopicModel(in_test_features, N_CONCEPTS_ORIG, thres=0.2, predict=predict_model)
    topic_model(in_test_features)
    topic_model.load_weights(args.result_dir+'/topic_latest.h5', by_name=True)

    ## load topic_vec
    topic_vec_orig = topic_model.layers[0].get_weights()[0]
    logger.info(f'Number of concepts before removing duplicate ones: {str(N_CONCEPTS_ORIG)}')

    topic_vec, _, concept_map = remove_duplicate_concepts(topic_vec_orig)
    N_CONCEPTS = np.shape(topic_vec)[-1]
    logger.info(f'Number of concepts after removing duplicate ones: {str(N_CONCEPTS)}')

    ######################################
    ## Disabled: visualize the nearest neighbors
    # if args.visualize:
    #     visualize_nn(in_loader, topic_vec, in_test_features, args.result_dir, logger)
    #
    ## Disabled: evaluating the difference between two worlds
    # logger.info("[ID TEST RECOVERED] performance of target OOD detector with test set...")
    # in_test_scores_recov, out_test_scores_recov, _, auroc_recov = run_eval(feature_model, topic_model, in_loader, out_loader, logger, args, N_CLASSES)
    # savefig = os.path.join(args.result_dir, 'plots', '{}_recov_AwA2_test_{}_test.jpg'.format(args.score, args.out_data))
    # plot_stats(in_test_scores_recov, out_test_scores_recov, savename=savefig)

    ###########################################
    ## Load necessary data and ConceptSHAP results.............
    shap_path = os.path.join(explain_dir,'{}_SHAP.pkl'.format(args.score))
    # NOTE: pickle.load can execute arbitrary code; the file is expected to
    # be one produced by this project.
    with open(shap_path,'rb') as f:
        shap_expl = pickle.load(f)
    in_test_yhat = shap_expl['in_yhat']
    out_test_yhat = shap_expl['out_yhat']
    in_test_concepts = shap_expl['in_concepts']
    out_test_concepts = shap_expl['out_concepts']
    in_test_scores = shap_expl['in_scores']
    out_test_scores = shap_expl['out_scores']
    thres95 = shap_expl['thres']
    # From here on the concept count is taken from the stored concept scores.
    N_CONCEPTS = in_test_concepts.shape[1]

    #######################################
    ## Load separability of concepts......
    separa_path = os.path.join(args.result_dir, 'separability_{}_AwA2_{}_multiv.npy'.format(args.score, args.out_data))
    separa = np.load(separa_path, allow_pickle=True).item()

    #######################################
    ## Generating explanations.....
    separa_global = separa['global']
    logger.info(f'[GLOBAL SEPARABILITY] multivariate separability: {separa_global}')
    separa_class = np.array([separa['class'+str(i)] for i in range(N_CLASSES)], dtype=np.float64)
    logger.info(f'[PER-CLASS SEPARABILIRY] averaged separability: {np.nanmean(separa_class)}')
    # Classes ordered by decreasing separability; classes with NaN are dropped.
    classes = np.argsort(separa_class)[::-1]
    classes = np.delete(classes, np.where(np.isnan(separa_class[classes]))[0])

    logger.info(f'classes with highest and lowest separabilities...: {classes}')
    k = 10
    for i in classes:
        # Concepts ranked by decreasing SHAP value (original concept indices).
        shap_class = shap_expl[f'shap_class_class{i}']
        shap_concept_orig_class = np.argsort(shap_class)[::-1]

        shap = shap_expl[f'shap_detect_class{i}']
        shap_concept_orig = np.argsort(shap)[::-1]

        # Same rankings expressed in post-deduplication concept indices.
        # (builtin int replaces the np.int alias removed in NumPy 1.24)
        shap_concept_class = _remap_concepts(shap_concept_orig_class, concept_map)
        shap_concept = _remap_concepts(shap_concept_orig, concept_map)

        idx_in = np.where(in_test_yhat == i)[0]
        idx_out = np.where(out_test_yhat == i)[0]
        in_concepts_ith = in_test_concepts[idx_in,:] # concept scores of ID data classified as class i
        out_concepts_ith = out_test_concepts[idx_out,:] # concept scores of OOD data classified as class i

        # Skip classes with fewer samples than concepts on either side.
        if len(idx_in) < N_CONCEPTS or len(idx_out) < N_CONCEPTS:
            continue

        # indices for OOD detection results
        idx_IN_IN = in_test_scores[idx_in] >= thres95   # ID detected as ID
        idx_IN_OUT = ~idx_IN_IN                 # ID detected as OOD
        idx_OUT_OUT = out_test_scores[idx_out] < thres95 # OOD detected as OOD
        idx_OUT_IN = ~idx_OUT_OUT               # OOD detected as ID

        print(np.sum(idx_IN_IN))
        print(np.sum(idx_IN_OUT))
        print(np.sum(idx_OUT_OUT))
        print(np.sum(idx_OUT_IN))

        print(f'class{i} classification | top-{k} concepts: {shap_concept_class[:k]}')
        print(f'class{i} classification | top-{k} conceptSHAP: {shap_class[shap_concept_orig_class[:k]]}')

        print(f'class{i} detection | top-{k} concepts: {shap_concept[:k]}')
        print(f'class{i} detection | top-{k} conceptSHAP: {shap[shap_concept_orig[:k]]}')

        # Group samples by the detector's decision rather than by ground truth.
        in_concepts_ith_detected = np.r_[in_concepts_ith[idx_IN_IN], out_concepts_ith[idx_OUT_IN]]
        out_concepts_ith_detected = np.r_[in_concepts_ith[idx_IN_OUT], out_concepts_ith[idx_OUT_OUT]]
        separa_topk = multivar_separa(in_concepts_ith_detected[:,shap_concept[:k]], out_concepts_ith_detected[:,shap_concept[:k]])
        print(f'class{i} | top-{k} separability: {separa_topk}')
        explain_topK(in_concepts_ith_detected, top_k=k, separa=separa_topk, concept_idx=shap_concept[:k],
                    figname=os.path.join(explain_dir,'class{}_AwA2_top{}_detected_{}.jpg'.format(i, k, args.score)))
        explain_topK(out_concepts_ith_detected, top_k=k, separa=separa_topk, concept_idx=shap_concept[:k],
                    figname=os.path.join(explain_dir,'class{}_{}_top{}_detected_{}.jpg'.format(i, args.out_data, k, args.score)))

        ## Disabled: relative comparison
        # scores = {}
        # labels = ['ID', 'OOD', 'ID->OOD', 'OOD->ID']
        # scores[labels[0]] = in_concepts_ith
        # scores[labels[1]] = out_concepts_ith
        # scores[labels[2]] = in_concepts_ith[idx_IN_OUT]
        # scores[labels[3]] = out_concepts_ith[idx_OUT_IN]
        # explain_relative(scores, labels, separa['class'+str(i)],
        # figname=os.path.join(explain_dir,'class{}_AwA2_{}_top{}_separability_{}.jpg'.format(i,args.out_data,k,args.score)),
        # figname_dist=os.path.join(explain_dir,'class{}_AwA2_{}_distribution.jpg'.format(i,args.out_data)),
        # top_k=k)

        # visualize example ID/OOD images
        in_files_ith = np.array(in_loader.filepaths)[idx_in]
        out_files_ith = np.array(out_loader.filepaths)[idx_out]
        save_images(in_files_ith[idx_IN_IN], figname=os.path.join(explain_dir,'class{}_{}_AwA2_IN'.format(i, args.score)))
        save_images(in_files_ith[idx_IN_OUT], figname=os.path.join(explain_dir,'class{}_{}_AwA2_OUT'.format(i, args.score)))
        save_images(out_files_ith[idx_OUT_OUT], figname=os.path.join(explain_dir,'class{}_{}_{}_OUT'.format(i, args.score, args.out_data)))
        save_images(out_files_ith[idx_OUT_IN], figname=os.path.join(explain_dir,'class{}_{}_{}_IN'.format(i, args.score, args.out_data)))

if __name__ == '__main__':
    # Script entry point: extend the shared parser from utils.test_utils with
    # the options specific to this script, then run main().
    cli = arg_parser()

    # (option strings, keyword arguments) for each extra command-line option,
    # registered in the order listed.
    extra_options = [
        (('--gpu',),
         dict(required=True, type=str)),
        (('--result_dir',),
         dict(type=str,
              help='path to directory where results from concept learning are stored',
              default='results/AwA2_coher10_energy1_feat0.1_concept100')),
        (('--visualize', '-visualize'),
         dict(action='store_true',
              help='whether to visualize nearest neighbors')),
        (('--out_data_dim',),
         dict(type=int, default=224, help='dimension of ood data')),
        (('--score',),
         dict(choices=['MSP', 'ODIN', 'Energy', 'Mahalanobis', 'KL_Div'],
              default='Energy')),
        (('--temperature_odin',),
         dict(default=1000, type=int, help='temperature scaling for odin')),
        (('--epsilon_odin',),
         dict(default=0.0, type=float, help='perturbation magnitude for odin')),
        (('--temperature_energy',),
         dict(default=1, type=int, help='temperature scaling for energy')),
    ]
    for flags, options in extra_options:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())
