import contextlib
import os
import pickle
import sys

import numpy as np
from absl import app
from scipy.stats import norm
from scipy.stats import t as tdist

from AwA2_eval import compute_completeness, compute_detection_completeness, compute_separability, compute_coherency, remove_duplicate_concepts
from utils.test_utils import get_measures

# Positional command-line arguments, read at import time.
# NOTE(review): `app.run` below also processes sys.argv; confirm the two do not conflict.
# Directory holding `topic_vec_orig.npy` and the results pickle; the report is written here too.
logdir = sys.argv[1]
# Score name; used only to build input/output file names in this file.
score = sys.argv[2]
# OOD dataset name; used only to build input/output file names in this file.
ood_data = sys.argv[3]
# Number of bootstrap resamples drawn by every bootstrap_* function.
n_trials = 200
# Two-sided confidence level used by compute_CI.
conf = 0.95

def compute_CI(X):
    """Return the mean of `X` and the half-width of its confidence interval.

    The half-width is ``q / sqrt(n) * s`` where ``s`` is the sample standard
    deviation (``ddof=1``) along axis 0, ``n`` is the number of observations
    along axis 0 of `X`, and ``q`` is the two-sided critical value at the
    module-level confidence level `conf`. The normal quantile is used when
    ``n >= 100`` and the Student-t quantile with ``n - 1`` degrees of freedom
    otherwise.

    Args:
        X: observations from bootstrapping; a sequence or array whose first
            axis indexes the trials.

    Returns:
        Tuple ``(avg_X, std_X)``: the mean over axis 0 and the confidence
        interval half-width (despite its name, not a raw standard deviation).
    """
    X = np.asarray(X)
    # Size the interval from the observations actually supplied rather than
    # from the global trial count, so the result stays correct if a caller
    # passes a different number of samples.
    n_obs = X.shape[0]

    upper_q = (1 + conf) / 2.
    if n_obs >= 100:
        crit = norm.ppf(upper_q)
    else:
        # use the t-distribution for a small number of observations
        crit = tdist.ppf(upper_q, df=(n_obs - 1))

    avg_X = np.mean(X, axis=0)
    std_X = crit / np.sqrt(n_obs) * np.std(X, axis=0, ddof=1)

    return avg_X, std_X

def bootstrap_coherency(ood_data):
    """Bootstrap the coherency metric over resampled feature rows.

    Loads the feature array and the concept vectors (from `logdir`, with
    duplicates removed), then repeats `n_trials` times: resample row indices
    with replacement and evaluate `compute_coherency` on the resampled rows.

    Args:
        ood_data: accepted for interface compatibility; not used in the body.

    Returns:
        The `(average, interval half-width)` pair produced by `compute_CI`.
    """
    # features first, then the concept vectors
    features = np.load('results/Animals_with_Attributes2/f_AwA2.npy')
    raw_concepts = np.load(os.path.join(logdir, 'topic_vec_orig.npy'))
    concepts, _ = remove_duplicate_concepts(raw_concepts)

    # fixed seed so repeated runs draw the same resamples
    sampler = np.random.RandomState(seed=12345)
    row_ids = np.arange(features.shape[0])
    n_rows = len(row_ids)

    def _one_trial():
        picked = sampler.choice(row_ids, size=n_rows, replace=True)
        return compute_coherency(features[picked], concepts)

    scores = [_one_trial() for _ in range(n_trials)]
    return compute_CI(scores)


def bootstrap_comp(y, yhat, yhat_recov, num_class=50):
    """Bootstrap the classification completeness metric.

    Repeats `n_trials` times: resample indices `0..len(y)-1` with replacement,
    index `y`, `yhat` and `yhat_recov` with the same resample, and evaluate
    `compute_completeness` on them.

    Args:
        y: true labels, indexable by an integer index array.
        yhat: predictions aligned with `y`.
        yhat_recov: predictions aligned with `y`; presumably those recovered
            from the concept representation (confirm against the producer).
        num_class: number of classes forwarded to `compute_completeness`.
            Defaults to 50, the value that was previously hard-coded.

    Returns:
        The `(average, interval half-width)` pair produced by `compute_CI`.
    """
    # fixed seed so repeated runs draw the same resamples
    rng = np.random.RandomState(seed=12345)
    idx = np.arange(len(y))

    comp = []
    for _ in range(n_trials):
        pred_idx = rng.choice(idx, size=len(idx), replace=True)
        comp.append(
            compute_completeness(
                y[pred_idx], yhat[pred_idx], yhat_recov[pred_idx],
                num_class=num_class,
            )
        )

    return compute_CI(comp)

def bootstrap_comp_detect(s_id, s_ood, s_id_recov, s_ood_recov):
    """Bootstrap the detection completeness metric.

    In-distribution and OOD scores are resampled independently (with
    replacement) in each of the `n_trials` rounds. The same index draws are
    applied to the original and the recovered scores, the first value
    returned by `get_measures` is taken for each pair, and both are passed to
    `compute_detection_completeness`.

    Args:
        s_id: in-distribution scores.
        s_ood: out-of-distribution scores.
        s_id_recov: recovered scores aligned with `s_id`.
        s_ood_recov: recovered scores aligned with `s_ood`.

    Returns:
        The `(average, interval half-width)` pair produced by `compute_CI`.
    """
    # fixed seed so repeated runs draw the same resamples
    sampler = np.random.RandomState(seed=12345)
    n_id, n_ood = len(s_id), len(s_ood)
    id_rows = np.arange(n_id)
    ood_rows = np.arange(n_ood)

    def _auroc(scores_in, scores_out):
        # get_measures returns five values; only the first is used here
        auroc, _, _, _, _ = get_measures(scores_in, scores_out)
        return auroc

    trials = []
    for _ in range(n_trials):
        take_id = sampler.choice(id_rows, size=n_id, replace=True)
        take_ood = sampler.choice(ood_rows, size=n_ood, replace=True)

        # scores are passed with a new trailing axis added by the `None` index
        original = _auroc(s_id[take_id, None], s_ood[take_ood, None])
        recovered = _auroc(s_id_recov[take_id, None], s_ood_recov[take_ood, None])
        trials.append(compute_detection_completeness(original, recovered))

    return compute_CI(trials)


def bootstrap_separa(concept_id, concept_ood, yhat_id, yhat_ood, num_class=50):
    """Bootstrap the median relative concept separability.

    A baseline separability dict is loaded from disk (file name built from
    the module-level `score` and `ood_data`). In each of the `n_trials`
    rounds, in-distribution and OOD rows are resampled independently with
    replacement, `compute_separability` is evaluated, and the per-class
    relative change ``(target - baseline) / baseline`` is collected. Classes
    whose baseline or target value is falsy are skipped. The median over
    classes is the trial's observation.

    Args:
        concept_id: concept scores for in-distribution samples (2-D indexable).
        concept_ood: concept scores for OOD samples (2-D indexable).
        yhat_id: predictions aligned with `concept_id`.
        yhat_ood: predictions aligned with `concept_ood`.
        num_class: number of `'class<k>'` keys to read from both dicts.

    Returns:
        The `(average, interval half-width)` pair produced by `compute_CI`.
    """
    # fixed seed so repeated runs draw the same resamples
    sampler = np.random.RandomState(seed=12345)
    id_rows = np.arange(len(yhat_id))
    ood_rows = np.arange(len(yhat_ood))

    baseline_path = 'results/AwA2_baseline_concept100/separability_{}_AwA2_{}_multiv.npy'.format(score,ood_data)
    baseline = np.load(baseline_path, allow_pickle=True).item()
    class_keys = ['class'+str(c) for c in range(num_class)]

    medians = []
    for _ in range(n_trials):
        take_id = sampler.choice(id_rows, size=len(id_rows), replace=True)
        take_ood = sampler.choice(ood_rows, size=len(ood_rows), replace=True)

        target = compute_separability(
            concept_id[take_id,:], concept_ood[take_ood,:],
            yhat_id[take_id], yhat_ood[take_ood],
        )

        # relative separability per class, skipping falsy entries
        relative = []
        for key in class_keys:
            base_val = baseline[key]
            target_val = target[key]
            if not (base_val and target_val):
                continue
            relative.append((target_val - base_val) / base_val)
        medians.append(np.median(relative))

    return compute_CI(medians)

def main(argv):
    """Run all bootstrap evaluations and write the report into `logdir`.

    Loads the results pickle for the configured `score` / `ood_data`, loads
    the test labels, and writes the confidence-interval summaries for
    coherency, classification completeness, detection completeness and
    relative concept separability to ``bootstrap_<score>_<ood_data>.txt``.

    Standard output is redirected to the report file only for the duration
    of the evaluations and is restored afterwards, including on error, so
    later writes to stdout do not hit a closed file.

    Args:
        argv: arguments handed over by `app.run`; not used here.
    """
    ## Load necessary data and ConceptSHAP results
    shap_path = os.path.join(logdir, 'results_{}_{}.pkl'.format(score, ood_data))
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # `logdir` at result files from a trusted source.
    with open(shap_path, 'rb') as f:
        shap_expl = pickle.load(f)
    in_test_yhat = shap_expl['in_yhat']
    out_test_yhat = shap_expl['out_yhat']
    in_test_yhat_recov = shap_expl['in_yhat_recov']
    in_test_concepts = shap_expl['in_concepts']
    out_test_concepts = shap_expl['out_concepts']
    in_test_scores = shap_expl['in_scores']
    out_test_scores = shap_expl['out_scores']
    in_test_scores_recov = shap_expl['in_scores_recov']
    out_test_scores_recov = shap_expl['out_scores_recov']

    # true labels: argmax over axis 1 of the stored label array
    y_test = np.argmax(np.load('/nobackup/jihye/data/Animals_with_Attributes2/y_test.npy'), axis=1)

    report_path = os.path.join(logdir, 'bootstrap_{}_{}.txt'.format(score, ood_data))
    # Everything printed inside this block (including by called helpers)
    # goes to the report file; stdout is restored and the file closed on exit.
    with open(report_path, 'w') as report, contextlib.redirect_stdout(report):
        avg_coher, std_coher = bootstrap_coherency(ood_data)
        print('CI {} | num_trials {}'.format(conf, n_trials))
        print('Coherency: avg: {} | std: {}'.format(avg_coher, std_coher))

        avg_comp, std_comp = bootstrap_comp(y_test, in_test_yhat, in_test_yhat_recov)
        print('Classification Completeness: avg: {} | std: {}'.format(avg_comp, std_comp))

        avg_comp_detect, std_comp_detect = bootstrap_comp_detect(
            in_test_scores, out_test_scores, in_test_scores_recov, out_test_scores_recov)
        print('Detection Completeness: avg: {} | std: {}'.format(avg_comp_detect, std_comp_detect))

        avg_separa, std_separa = bootstrap_separa(
            in_test_concepts, out_test_concepts, in_test_yhat, out_test_yhat, num_class=50)
        print('Relative Concept Separability: avg: {} | std: {}'.format(avg_separa, std_separa))


# Script entry point: hand `main` to absl's application runner.
if __name__ == '__main__':
    app.run(main)
