import numpy as np
import os
import sys
import copy
import pickle
from scipy.special import comb
from absl import app, flags
from PIL import Image
import matplotlib.pyplot as plt
# Global matplotlib configuration for every figure this script saves.
# The style sheet is applied first so the overrides below take precedence.
plt.style.use('ggplot')
import matplotlib
# Font type written into PDF / PostScript output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Font sizes (in points) referenced by the plt.rc calls below.
SMALL_SIZE=12
MEDIUM_SIZE=15
BIGGER_SIZE=20
plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
#plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
#plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import ipca_v2
import AwA2_helper
from AwA2_eval import compute_completeness, compute_detection_completeness, compute_separability
from AwA2_eval import compute_concept_scores, remove_duplicate_concepts
from utils.test_utils import get_measures
from utils.ood_utils import run_ood_over_batch
from utils.stat_utils import multivar_separa
import ipdb

# absl flag registry; the flags themselves are defined in the __main__ guard.
FLAGS = flags.FLAGS

# Predicted-class labels that main() computes ConceptSHAP for and dumps
# debug output about. The longer list is kept for reference.
#classes = [0, 8, 11, 12, 17, 25, 30, 32, 35, 40, 45]
classes = [0, 8, 11, 12, 17]

def load_data(feature_model):
    """Load features, detection results and labels into module-level globals.

    Sets the globals ``in_test_filepaths``, ``out_test_filepaths``,
    ``in_test_features``, ``out_test_features``, ``results`` and ``y``.
    Feature arrays are read from ``.npy`` caches under
    ``results/Animals_with_Attributes2`` when they exist; otherwise they are
    computed with ``feature_model.predict`` and written to that cache.

    Args:
        feature_model: model whose ``predict`` is applied to the image
            loaders to obtain the intermediate feature representations.
    """
    global in_test_filepaths, out_test_filepaths
    global in_test_features, out_test_features
    global results, y

    # Extract intermediate feature representations...............
    f_in_path = 'results/Animals_with_Attributes2/f_AwA2.npy'
    f_out_path = 'results/Animals_with_Attributes2/f_{}.npy'.format(FLAGS.ood_data)

    # shuffle=False keeps loader order aligned with `filepaths`.
    datagen = ImageDataGenerator(rescale=1.0 / 255.)
    in_loader = datagen.flow_from_directory("/nobackup/jihye/data/Animals_with_Attributes2/test",
                                            batch_size=256,
                                            target_size=(224,224),
                                            class_mode='categorical', shuffle=False)
    out_loader = datagen.flow_from_directory("/nobackup/jihye/data/{}".format(FLAGS.ood_data),
                                             batch_size=256,
                                             target_size=(224,224),
                                             class_mode=None, shuffle=False)
    in_test_filepaths = in_loader.filepaths
    out_test_filepaths = out_loader.filepaths

    def _cached_features(cache_path, loader):
        """Return features from ``cache_path``, computing and caching them if absent."""
        if os.path.exists(cache_path):
            return np.load(cache_path)
        features = feature_model.predict(loader)
        np.save(cache_path, features)
        return features

    # The loaders built above are reused here; reading `.filepaths` does not
    # consume them, so there is no need to rescan the image directory.
    in_test_features = _cached_features(f_in_path, in_loader)
    out_test_features = _cached_features(f_out_path, out_loader)

    ## Load necessary data and ConceptSHAP results.............
    result_path = os.path.join(FLAGS.logdir,'results_{}_{}.pkl'.format(FLAGS.score, FLAGS.ood_data))
    # NOTE: pickle executes arbitrary code on load; only point --logdir at
    # result files produced by this project.
    with open(result_path,'rb') as f:
        results = pickle.load(f)

    # True labels: argmax over axis 1 of the stored label matrix.
    y = np.argmax(np.load('/nobackup/jihye/data/Animals_with_Attributes2/y_test.npy'), axis=1)


def true_to_detected(in_score, out_score, in_concept, out_concept, thres, in_yhat=None, out_yhat=None):
    """Regroup true-ID / true-OOD samples into detected-ID / detected-OOD.

    A sample is "detected ID" when its score is ``>= thres`` and "detected
    OOD" otherwise, regardless of which set it came from. Within each
    detected group the ID-set samples come first, followed by the OOD-set
    samples.

    Args:
        in_score, out_score: per-sample detection scores of the two sets.
        in_concept, out_concept: per-sample concept values, indexed along
            the first axis by the score masks.
        thres: detection threshold.
        in_yhat, out_yhat: optional per-sample predictions; regrouped the
            same way only when both are given.

    Returns:
        ``(in_detect_concept, out_detect_concept)``, extended by
        ``(in_detect_yhat, out_detect_yhat)`` when both predictions are given.
    """
    accepted_in = in_score >= thres
    accepted_out = out_score >= thres

    def _regroup(in_values, out_values):
        # Stack ID-set rows first, then OOD-set rows, for each detected group.
        detected_id = np.r_[in_values[accepted_in], out_values[accepted_out]]
        detected_ood = np.r_[in_values[~accepted_in], out_values[~accepted_out]]
        return detected_id, detected_ood

    concept_groups = _regroup(in_concept, out_concept)
    if in_yhat is None or out_yhat is None:
        return concept_groups
    return concept_groups + _regroup(in_yhat, out_yhat)


def iterate_mask(concept_mask, topic_vec, topic_model, feature_model, label):
    """Evaluate one concept mask for the samples predicted as ``label``.

    The columns of ``topic_vec`` whose mask entry is 0 are zeroed and
    installed as the weights of ``topic_model.layers[0]`` (the model is
    mutated in place). Four metrics are then computed on the samples whose
    original prediction (``results['in_yhat']`` / ``results['out_yhat']``)
    equals ``label``.

    Reads the module globals ``results``, ``in_test_features``,
    ``out_test_features`` and ``y`` set by ``load_data``.

    Args:
        concept_mask: sequence with one entry per concept; 0 disables it.
        topic_vec: concept matrix; its columns are indexed by the mask.
        topic_model: model whose first layer holds the concept matrix.
        feature_model: forwarded to ``run_ood_over_batch``.
        label: predicted class to restrict the evaluation to.

    Returns:
        ``(comp_class, comp_detect, separa, separa_detected)``, each a
        length-1 array. When no ID or no OOD sample is predicted as
        ``label``, the last three are ``[nan]`` because they need both sets.
    """
    yhat_in = results['in_yhat']
    yhat_out = results['out_yhat']

    ## modify topic model
    topic_vec_temp = copy.copy(topic_vec)
    topic_vec_temp[:,np.array(concept_mask)==0] = 0
    topic_model.layers[0].set_weights([topic_vec_temp])
    # Reference AUROC from the stored (unmasked) detection scores.
    auroc, _, _, _, _ = get_measures(results['in_scores'][:,None], results['out_scores'][:,None])

    # get new predictions
    _, logits_in_new, _ = topic_model(in_test_features)
    yhat_in_recov_new = tf.math.argmax(logits_in_new, axis=1).numpy()

    # get new concept scores
    # NOTE(review): these use the unmasked `topic_vec`, so both separability
    # values are independent of `concept_mask`; confirm whether
    # `topic_vec_temp` was intended.
    concept_in = compute_concept_scores(topic_vec, in_test_features).numpy()
    concept_out = compute_concept_scores(topic_vec, out_test_features).numpy()

    num_class = 50

    idx_in = np.where(yhat_in == label)[0]
    idx_out = np.where(yhat_out == label)[0]

    # 1) compute classification completeness
    comp_class = [compute_completeness(y[idx_in], yhat_in[idx_in], yhat_in_recov_new[idx_in], num_class)]

    # Detection completeness and separability need both sets to be non-empty.
    # Return NaN placeholders instead of scoring empty batches.
    if len(idx_in) == 0 or len(idx_out) == 0:
        return (np.array(comp_class), np.array([np.nan]),
                np.array([np.nan]), np.array([np.nan]))

    # 2) compute detection completeness
    s_in_new = run_ood_over_batch(None, feature_model, topic_model, FLAGS, num_class, in_test_features[idx_in])
    s_out_new = run_ood_over_batch(None, feature_model, topic_model, FLAGS, num_class, out_test_features[idx_out])
    auroc_recov_new, _, _, _, thres = get_measures(s_in_new[:,None],s_out_new[:,None])
    comp_detect = [compute_detection_completeness(auroc, auroc_recov_new)]

    # 3) compute separability
    # between true-ID vs true-OOD
    separa = [multivar_separa(concept_in[idx_in,:], concept_out[idx_out,:]).numpy()]
    # between detected-ID vs detected-OOD
    concept_in_detected, concept_out_detected = true_to_detected(
        s_in_new, s_out_new, concept_in[idx_in,:], concept_out[idx_out,:], thres)
    separa_detected = [multivar_separa(concept_in_detected, concept_out_detected).numpy()]

    return np.array(comp_class), np.array(comp_detect), np.array(separa), np.array(separa_detected)



def kernelSHAP(nc, inputs, topic_vec, topic_model, feature_model, label):
    """Estimate per-concept SHAP values for one class by weighted regression.

    Every mask in ``inputs`` is evaluated with ``iterate_mask``. For each of
    the four metrics the SHAP vector is the weighted least-squares solution
    ``pinv(X^T K X) @ X^T K y``, where ``X`` is the mask matrix, ``K`` the
    diagonal matrix of kernel weights and ``y`` the metric values.

    The results are written into the module-level ``shap_expl`` dict (which
    must already exist) under ``shap_{class,detect,separa,separa_detected}_class{label}``.

    Args:
        nc: number of concepts, used in the kernel weight.
        inputs: 2-D collection of binary concept masks, one per row.
        topic_vec, topic_model, feature_model: forwarded to ``iterate_mask``.
        label: predicted class the explanation is computed for.

    Returns:
        The updated module-level ``shap_expl`` dict.
    """

    def shap_kernel_adjust(n, k, p=0.5):
        """Return the KernelSHAP weight for a mask with ``k`` of ``n`` concepts on,
        divided by the probability ``p**k * (1-p)**(n-k)``."""
        return (n-1)*1.0/((n-k)*k*comb(n, k) + 1e-8) / (np.power(p,k)*np.power(1-p,n-k) + 1e-8)

    outputs_class, outputs_detect, outputs_separa, outputs_separa_detected = [], [], [], []
    kernel = []

    for concept_mask in inputs:
        comp_class, comp_detect, separa, separa_detected = iterate_mask(
            concept_mask, topic_vec, topic_model, feature_model, label)
        outputs_class.append(comp_class)
        outputs_detect.append(comp_detect)
        outputs_separa.append(separa)
        outputs_separa_detected.append(separa_detected)
        kernel.append(shap_kernel_adjust(nc, np.sum(concept_mask)))

    k = np.array(kernel)
    # Cap infinite weights on the array that is actually used below;
    # otherwise 0 * inf inside the diag product would yield NaN.
    k[k == np.inf] = 1e+4
    x = np.array(inputs)

    # X^T K and pinv(X^T K X) do not depend on the metric: compute them once.
    xk = np.matmul(x.transpose(), np.diag(k))
    xkx_pinv = np.linalg.pinv(np.matmul(xk, x))

    def compute_shap(outputs, name):
        """Solve the weighted regression for one metric and store the result."""
        targets = np.array(outputs)
        xky = np.matmul(xk, targets)
        shap = np.matmul(xkx_pinv, xky)
        shap_expl[f'shap_{name}_class{label}'] = shap.flatten()

    compute_shap(outputs_class, 'class')
    compute_shap(outputs_detect, 'detect')
    compute_shap(outputs_separa, 'separa')
    compute_shap(outputs_separa_detected, 'separa_detected')

    return shap_expl


def load_shap(expl, i, dupl_mapping, top_k=5):
    """Print the top-ranked concepts of class ``i`` and return the detect ranking.

    For each of the four SHAP vectors stored in ``expl`` (class, detect,
    separa, separa_detected) the ``top_k`` concepts with the largest
    absolute value are printed together with their SHAP values.

    Args:
        expl: dict holding ``shap_<kind>_class{i}`` arrays and a ``'mask'`` entry.
        i: class label used to build the dict keys.
        dupl_mapping: indexable by concept index; entries that are ``None``
            are skipped.
        top_k: number of concepts kept per ranking.

    Returns:
        Array of ``dupl_mapping`` values for the top detect concepts, in
        ranking order, with ``None`` entries dropped.
    """
    kinds = ('class', 'detect', 'separa', 'separa_detected')
    values = {kind: expl[f'shap_{kind}_class{i}'].flatten() for kind in kinds}
    expl['mask']  # dim=(len(iter), num_concept); looked up but not used further

    # Indices of the top_k largest |SHAP| values, in descending order.
    ranking = {kind: np.argsort(np.abs(values[kind]))[::-1][:top_k] for kind in kinds}

    for kind in kinds:
        title = kind.replace('_', ' ')
        print(f'[class {i}] concepts {title}: {ranking[kind]}')
        print(f'[class {i}] conceptSHAP {title}: {values[kind][ranking[kind]]}')
    print('==========================================\n')
    print(f'[class {i}] conceptSHAP detect with high separa detected: {values["detect"][ranking["separa_detected"]]}')
    print(f'[class {i}] conceptSHAP separa detected with high detect: {values["separa_detected"][ranking["detect"]]}')

    # Translate the detect ranking through the mapping, skipping None entries.
    mapped = [dupl_mapping[concept] for concept in ranking['detect']
              if dupl_mapping[concept] is not None]
    return np.array(mapped)

def plot_profile(in_scores, out_scores, concepts, figname):
    """Save a grouped bar chart of mean concept scores for two sample sets.

    For every concept in ``concepts`` a green bar (first set) and an orange
    bar (second set) show the column mean, with the column standard
    deviation as error bar. The multivariate separability of the selected
    columns is printed and shown in the title.

    Args:
        in_scores: 2-D concept scores of the first set (legend "Detected-ID").
        out_scores: 2-D concept scores of the second set (legend "Detected-OOD").
        concepts: column indices to plot.
        figname: path the figure is saved to.
    """
    mean_by_set = (np.mean(in_scores, axis=0), np.mean(out_scores, axis=0))

    n_series = 2
    width = 0.35
    # Left edge of each concept group; one bar-width gap between groups.
    positions = np.arange(len(concepts)) * width * (n_series + 1)

    separa = multivar_separa(in_scores[:,concepts], out_scores[:,concepts]).numpy()
    print(separa)

    fig, ax = plt.subplots(figsize=(3*len(concepts)/5,3))

    series = zip(mean_by_set, (in_scores, out_scores), ('tab:green', 'tab:orange'))
    for offset, (means, scores, colour) in enumerate(series):
        ax.bar(positions + offset * width, means[concepts],
               width, yerr=np.std(scores[:,concepts],axis=0), color=colour)

    ax.set_title('Concept Separability: {0:.3f}'.format(separa),fontsize=9)
    ax.set_ylabel('Concept score')
    ax.set_ylim(-0.3,1)
    ax.set_xticks(positions + n_series * width / 2)
    ax.set_xticklabels(['C{}'.format(c) for c in concepts])
    ax.legend(['Detected-ID', 'Detected-OOD'],fontsize=9)

    fig.tight_layout()
    plt.savefig(figname)
    plt.close()


def plot_with_profile(in_scores, out_scores, in_score, out_score, concepts, figname, flag='correct'):
    """Save a bar chart comparing two single samples against the ID profile.

    At most the first five entries of ``concepts`` are plotted. Each concept
    gets three bars: the mean of ``in_scores`` (with standard deviation),
    the value of ``in_score`` and the value of ``out_score``.

    Args:
        in_scores: 2-D concept scores forming the ID profile.
        out_scores: 2-D concept scores of the other set; its mean is
            computed but the corresponding bar is currently disabled.
        in_score: 1-D concept scores of a single sample from the ID set.
        out_score: 1-D concept scores of a single sample from the OOD set.
        concepts: concept indices to plot.
        figname: path the figure is saved to.
        flag: ``'correct'`` or ``'wrong'`` selects the legend labels; any
            other value draws no legend.
    """
    profile_mean = np.mean(in_scores, axis=0)
    np.mean(out_scores, axis=0)  # OOD profile mean; not plotted at the moment

    concepts = concepts[:5] if len(concepts) > 5 else concepts

    n_series = 3
    width = 0.35
    positions = np.arange(len(concepts)) * width * (n_series + 1)

    fig, ax = plt.subplots(figsize=(5*len(concepts)/5,3))
    ax.bar(positions, profile_mean[concepts], width,
           yerr=np.std(in_scores[:,concepts],axis=0), color='tab:green', alpha=0.7)
    ax.bar(positions + 1 * width, in_score[concepts], width, color='tab:green')
    ax.bar(positions + 2 * width, out_score[concepts], width, color='tab:orange')

    ax.set_title('Ours')
    ax.set_ylabel('Concept score')
    ax.set_ylim(-0.3,1)
    ax.set_xticks(positions + n_series * width / 2)
    ax.set_xticklabels(['C{}'.format(c) for c in concepts])

    legends = {
        'correct': ['ID Profile', 'ID -> ID', 'OOD -> OOD'],
        'wrong': ['ID Profile', 'ID -> OOD', 'OOD -> ID'],
    }
    labels = legends.get(flag)
    if labels is not None:
        ax.legend(labels, fontsize=9)

    fig.tight_layout()
    plt.savefig(figname)
    plt.close()

def debug(in_score, out_score, in_concept, out_concept, thres, concept_idx, label,
        in_paths, out_paths):
    """Dump plots, concept scores and thumbnails for one predicted class.

    Samples are split by ``thres`` into IN_IN / IN_OUT (ID samples with
    score ``>=`` / ``<`` the threshold) and OUT_IN / OUT_OUT (the same for
    OOD samples). The function then

    * saves the detected-ID vs detected-OOD concept profile,
    * saves per-sample comparison plots for a fixed list of index pairs,
    * prints the selected concept scores and saves a 200x200 thumbnail for
      every IN_OUT and OUT_IN sample and for the first 16 IN_IN and
      OUT_OUT samples.

    Output goes to ``FLAGS.logdir`` and ``FLAGS.logdir/debug/<score>_<ood_data>``,
    which must already exist.

    Args:
        in_score, out_score: detection scores of the ID / OOD samples.
        in_concept, out_concept: 2-D concept scores of the ID / OOD samples.
        thres: detection threshold.
        concept_idx: concept columns to show.
        label: class label used in file names.
        in_paths, out_paths: image paths aligned with the score arrays.
    """
    idx_IN_IN = np.where(in_score >= thres)[0]
    idx_IN_OUT = np.where(in_score < thres)[0]
    idx_OUT_IN = np.where(out_score >= thres)[0]
    idx_OUT_OUT = np.where(out_score < thres)[0]

    debug_dir = os.path.join(FLAGS.logdir, 'debug', f'{FLAGS.score}_{FLAGS.ood_data}')

    plot_profile(in_concept[idx_IN_IN,:], out_concept[idx_OUT_OUT,:], concept_idx,
        figname=os.path.join(FLAGS.logdir, f'{FLAGS.score}_{FLAGS.ood_data}_class{label}_detected.jpg'))

    def _plot(i1, i2):
        """Plot the i1-th / i2-th correctly and wrongly detected sample pair."""
        # Skip the pair when the class has too few correctly detected samples.
        if i1 < len(idx_IN_IN) and i2 < len(idx_OUT_OUT):
            plot_with_profile(in_concept[idx_IN_IN,:], out_concept[idx_OUT_OUT,:],
                              in_concept[idx_IN_IN[i1],:], out_concept[idx_OUT_OUT[i2],:], concept_idx,
                              figname=os.path.join(debug_dir, f'class{label}_correct_IN_{idx_IN_IN[i1]}_OUT_{idx_OUT_OUT[i2]}_plot.jpg'),
                              flag='correct')

        if len(idx_IN_OUT) == 0 or len(idx_OUT_IN) == 0:
            return
        # Misdetections are rarer: fall back to the first one when out of range.
        if i1 >= len(idx_IN_OUT):
            i1 = 0
        if i2 >= len(idx_OUT_IN):
            i2 = 0
        plot_with_profile(in_concept[idx_IN_IN,:], out_concept[idx_OUT_OUT,:],
                          in_concept[idx_IN_OUT[i1],:], out_concept[idx_OUT_IN[i2],:], concept_idx,
                          figname=os.path.join(debug_dir, f'class{label}_wrong_IN_{idx_IN_OUT[i1]}_OUT_{idx_OUT_IN[i2]}_plot.jpg'),
                          flag='wrong')

    for i1, i2 in ((1, 1), (1, 2), (0, 0), (2, 1), (2, 2), (3, 1), (3, 2), (5, 5)):
        _plot(i1, i2)

    def _dump(tag, file_tag, concept, paths, indices):
        """Print concept scores and save a thumbnail for each sample in ``indices``."""
        for i, idx in enumerate(indices):
            print(f'[{tag}] {i}th... {concept[idx,concept_idx]}')
            # LANCZOS is the filter formerly exposed as Image.ANTIALIAS.
            with Image.open(paths[idx]) as img:
                thumb = img.resize((200,200), Image.LANCZOS)
            thumb.save(os.path.join(debug_dir, f'class{label}_{file_tag}_{idx}.jpg'))

    print(len(idx_IN_OUT))
    print(len(idx_OUT_IN))
    _dump('IN_OUT', 'ID_OUT', in_concept, in_paths, idx_IN_OUT)
    _dump('OUT_IN', 'OUT_IN', out_concept, out_paths, idx_OUT_IN)

    max_examples = 16
    _dump('IN_IN', 'IN_IN', in_concept, in_paths, idx_IN_IN[:max_examples])
    # OUT_OUT indices refer to the OOD arrays, so read `out_concept` here.
    _dump('OUT_OUT', 'OUT_OUT', out_concept, out_paths, idx_OUT_OUT[:max_examples])


def main(_):
    """Compute (or load cached) ConceptSHAP values and dump per-class debug output.

    Steps:
      1. Load the feature/prediction models, the cached data and the topic
         model weights from ``FLAGS.logdir``.
      2. If ``SHAP_<score>_<ood_data>.pkl`` is missing from ``FLAGS.logdir``,
         run ``kernelSHAP`` for every label in ``classes`` and cache the
         result; otherwise load the cache.
      3. For every label, write the ranking printed by ``load_shap`` and the
         output of ``debug`` to
         ``FLAGS.logdir/debug/<score>_<ood_data>/<score>_<ood_data>_class<label>_debug.txt``.

    Args:
        _: value supplied by ``app.run``; also forwarded as the first two
           arguments of ``load_model_inception_new``.
    """
    import contextlib  # only needed for the per-class stdout redirection

    global shap_expl

    ## Load models..................
    feature_model, predict_model = AwA2_helper.load_model_inception_new(_, _, input_size=(224,224), pretrain=True, n_gpus=1, modelname="results/Animals_with_Attributes2/inceptionv3_AwA2.h5")
    load_data(feature_model)

    N_CONCEPTS_ORIG = 100
    topic_model = ipca_v2.TopicModel(in_test_features, N_CONCEPTS_ORIG, thres=0.2, predict=predict_model)
    topic_model(in_test_features)  # call once before loading weights
    topic_model.load_weights(FLAGS.logdir+'/topic_latest.h5', by_name=True)
    topic_vec_orig = topic_model.layers[0].get_weights()[0]

    topic_vec, dict_dupl_topic, dupl_mapping = remove_duplicate_concepts(topic_vec_orig, return_mapping=True)

    # Detection threshold on the stored scores; used by both steps below.
    _, _, _, _, thres = get_measures(results['in_scores'][:,None], results['out_scores'][:,None])

    shap_path = os.path.join(FLAGS.logdir,'SHAP_{}_{}.pkl'.format(FLAGS.score, FLAGS.ood_data))
    if not os.path.exists(shap_path):
        in_detect_concept, out_detect_concept, in_detect_yhat, out_detect_yhat = true_to_detected(
            results['in_scores'], results['out_scores'],
            results['in_concepts'], results['out_concepts'],
            thres, results['in_yhat'], results['out_yhat'])
        separa_all = compute_separability(in_detect_concept, out_detect_concept, in_detect_yhat, out_detect_yhat)
        print(separa_all)

        ## Compute ConceptSHAP......
        nc = np.shape(topic_vec_orig)[-1] # number of concepts before duplicate removal
        # Enumerating all 2**nc masks is too expensive. Use one leave-group-out
        # mask per concept group, plus the all-ones mask.
        # NOTE(review): the dict keys are used as row indices, so they are
        # assumed to be 0..len(dict_dupl_topic)-1 — confirm.
        inputs = np.ones((len(dict_dupl_topic),nc))
        for d in dict_dupl_topic:
            idx = [d] + dict_dupl_topic[d]
            inputs[d,idx] = 0
        inputs = np.unique(inputs, axis=0)
        inputs = np.r_[np.ones((1,nc)), inputs]

        shap_expl = {'mask': inputs}
        for label in classes:
            shap_expl = kernelSHAP(nc, inputs, topic_vec_orig, topic_model, feature_model, label)

        with open(shap_path,'wb') as f:
            pickle.dump(shap_expl, f)
    else:
        # NOTE: pickle executes arbitrary code on load; only use caches
        # written by this script.
        with open(shap_path, 'rb') as f:
            shap_expl = pickle.load(f)

    print(dict_dupl_topic)

    debug_dir = os.path.join(FLAGS.logdir,'debug',f'{FLAGS.score}_{FLAGS.ood_data}')
    os.makedirs(debug_dir, exist_ok=True)

    for label in classes:
        log_path = os.path.join(debug_dir, '{}_{}_class{}_debug.txt'.format(FLAGS.score, FLAGS.ood_data, label))
        # Redirect prints into the per-class log. The context managers restore
        # sys.stdout and close the file even if an exception is raised.
        with open(log_path, 'w') as log_file, contextlib.redirect_stdout(log_file):
            concept_idx = load_shap(shap_expl, label, dupl_mapping, top_k=15)
            idx_in = np.where(results['in_yhat']==label)[0]
            idx_out = np.where(results['out_yhat']==label)[0]
            debug(results['in_scores'][idx_in], results['out_scores'][idx_out],
                results['in_concepts'][idx_in,:], results['out_concepts'][idx_out,:],
                thres, concept_idx, label,
                np.array(in_test_filepaths)[idx_in], np.array(out_test_filepaths)[idx_out])

if __name__ == '__main__':
    # Command-line flags read through FLAGS elsewhere in this script.
    flags.DEFINE_string(
        name='ood_data',
        default='Places',
        help='OOD Dataset.',
    )
    flags.DEFINE_string(
        name='logdir',
        default='results/AwA2_coher10_energy1_feat0.1_separaEnergy_multiv_coeff10_concept100/',
        help='Directory where the results to be saved.',
    )
    flags.DEFINE_string(
        name='score',
        default='Energy',
        help='OOD detection method',
    )
    flags.DEFINE_integer(
        name='temperature_odin',
        default=1000,
        help='temperature scaling for odin',
    )
    flags.DEFINE_float(
        name='epsilon_odin',
        default=0.0,
        help='perturbation magnitude for odin',
    )
    flags.DEFINE_integer(
        name='temperature_energy',
        default=1,
        help='temperature scaling for energy',
    )

    app.run(main)
