import tensorflow as tf
from scipy import stats
from skimage.metrics import structural_similarity as ssim
import numpy as np

#top k


def top_k_intersection(arr1,
                       arr2,
                       k,
                       axis=-1,
                       return_ranks=False,
                       return_values=False):
    """Measure the overlap between the top-k entries of two score arrays.

    The indices of the ``k`` largest entries of each input are found with
    ``tf.math.top_k`` and intersected with ``np.intersect1d``.

    Args:
        arr1: First array of scores.
        arr2: Second array of scores.
        k: Number of top entries to compare.
        axis: Accepted for interface compatibility but currently unused;
            ``tf.math.top_k`` is called without an axis argument.
        return_ranks: If true, return ``(topk1, topk2, inter)``: the top-k
            indices of each input and their intersection.
        return_values: If true (and ``return_ranks`` is false), return
            ``(topk1, topk2, vals1, vals2, inter)``, which additionally
            includes the top-k values of each input.

    Returns:
        By default, ``len(inter) / k``, the fraction of the top-k indices
        shared by both inputs. See ``return_ranks`` / ``return_values`` for
        the alternative tuple results.
    """
    # NOTE: for inputs with more than one dimension the index arrays are
    # flattened by np.intersect1d, so the ratio below is only a per-vector
    # overlap when the inputs are 1-D.
    vals1, topk1 = tf.math.top_k(arr1, k)
    vals2, topk2 = tf.math.top_k(arr2, k)

    inter = np.intersect1d(topk1, topk2)
    if return_ranks:
        return topk1, topk2, inter
    if return_values:
        return topk1, topk2, vals1, vals2, inter
    return inter.shape[0] / k


#spearman's rank
#returns corr, p-value
#0 is no relationship, goes from -1 to +1
#does not assume data is gaussian
#stats.spearmanr(a, b=None, axis=0)

#pearson's returns r (the correlation coefficient) and p-value
#0 is no relationship, goes from -1 to +1
#assumes data is gaussian
#stats.pearsonr(x, y)

#SSIM
#ssim(img1, img2, data_range=img.max() - img.min())


#MSE
def mse(A, B, ax=None):
    """Return the mean squared error between ``A`` and ``B``.

    Args:
        A: First array.
        B: Second array, combined with ``A`` via ``A - B``.
        ax: Axis (or axes) passed to ``.mean(axis=...)``. Defaults to
            ``None``, which averages over every element.

    Returns:
        The mean of ``(A - B) ** 2`` along ``ax``.
    """
    squared_error = (A - B)**2
    return squared_error.mean(axis=ax)


def get_metrics(expls1, expls2, k, SSIM=False):
    """Compute pairwise similarity metrics between two sets of explanations.

    ``expls1`` and ``expls2`` are iterated in lockstep; every metric is
    evaluated once per pair and the per-pair results are stacked with
    ``np.array``.

    Args:
        expls1: Iterable of explanation arrays.
        expls2: Iterable of explanation arrays paired with ``expls1``.
        k: Number of top entries used by ``top_k_intersection``.
        SSIM: If true, also compute the structural similarity index for
            each pair.

    Returns:
        ``(sp_rk, pr_rk, mse, topk, l2)``, or
        ``(ssim, sp_rk, pr_rk, mse, topk, l2)`` when ``SSIM`` is true.
        ``sp_rk`` and ``pr_rk`` hold the raw ``stats.spearmanr`` /
        ``stats.pearsonr`` results per pair (correlation and p-value).
    """
    # Materialise the pairs once so that one-shot iterables (generators)
    # are not exhausted by the first metric.
    pairs = list(zip(expls1, expls2))

    sp_rk = np.array([stats.spearmanr(expl1, expl2) for expl1, expl2 in pairs])
    pr_rk = np.array([stats.pearsonr(expl1, expl2) for expl1, expl2 in pairs])
    # The result must not be named `mse`: that would shadow the module-level
    # function inside this scope and make the call itself fail. The axis is
    # passed explicitly as None so each pair yields the mean over all
    # elements.
    mse_scores = np.array([mse(expl1, expl2, None) for expl1, expl2 in pairs])
    topk = np.array(
        [top_k_intersection(expl1, expl2, k) for expl1, expl2 in pairs])
    l2 = np.array(
        [np.linalg.norm(np.abs(expl1 - expl2)) for expl1, expl2 in pairs])

    if not SSIM:
        return sp_rk, pr_rk, mse_scores, topk, l2

    # Keep the boolean flag and the score array in separate names; testing
    # the truth value of a multi-element array is ambiguous.
    ssim_scores = np.array([
        ssim(expl1,
             expl2,
             data_range=max(expl1.max(), expl2.max()) -
             min(expl1.min(), expl2.min())) for expl1, expl2 in pairs
    ])
    return ssim_scores, sp_rk, pr_rk, mse_scores, topk, l2


def batch_flatten(x):
    """Collapse every axis after the first into a single trailing axis."""
    leading = x.shape[0]
    flat_shape = (leading, -1)
    return np.reshape(x, flat_shape)


def convert_to_super_labels(preds, affinity_set):
    """Relabel ``preds`` in place so each affinity group shares one label.

    Every entry of ``preds`` equal to a member of a group is overwritten
    with that group's first element. ``preds`` itself is modified and
    returned.
    """
    for group in affinity_set:
        for member in group:
            matches = preds == member
            preds[matches] = group[0]
    return preds

def invalidation(counterfactuals,
                 modelA_counterfual_preds,
                 modelB,
                 modelA_pred=None,
                 batch_size=512,
                 aggregation=None,
                 return_pred_B=False,
                 affinity_set=None):
    """Compute how often model A's counterfactuals fail to transfer to model B.

    Model B is evaluated on ``counterfactuals``. A counterfactual counts as
    still valid when B's predicted label on it differs from model A's
    prediction ``modelA_pred``; the returned rate is one minus the fraction
    of valid counterfactuals.

    Args:
        counterfactuals: Inputs passed to ``modelB.predict``.
        modelA_counterfual_preds: Model A's labels on the counterfactuals.
            Only used to derive ``modelA_pred`` (as ``1 - value``) when
            ``modelA_pred`` is not given, which presumes 0/1 labels.
        modelB: Object exposing ``predict(inputs, batch_size=...)`` that
            returns a 2-D array with one column per class.
        modelA_pred: Model A's labels to compare against. Defaults to
            ``1 - modelA_counterfual_preds``.
        batch_size: Forwarded to ``modelB.predict``.
        aggregation: ``None`` to return per-sample confidences, the string
            ``'mean'`` for ``np.mean``, or a callable applied to the
            confidence array.
        return_pred_B: If true, also return model B's predicted labels.
        affinity_set: Groups of labels treated as equivalent when model B
            has more than two output columns. Defaults to
            ``[[0], [1, 2]]``.

    Returns:
        ``(rate, confidence)``, or ``(rate, (confidence, pred_B))`` when
        ``return_pred_B`` is true. ``confidence`` holds model B's maximum
        output for the samples where its label differs from
        ``modelA_pred`` (aggregated if ``aggregation`` is set).
    """
    # Build the default per call rather than sharing one mutable list
    # object across every invocation.
    if affinity_set is None:
        affinity_set = [[0], [1, 2]]

    if aggregation == 'mean':
        aggregation = np.mean

    if modelA_pred is None:
        modelA_pred = 1 - modelA_counterfual_preds

    modelB_counterfactual_probits = modelB.predict(counterfactuals,
                                                   batch_size=batch_size)

    is_binary = modelB_counterfactual_probits.shape[1] <= 2

    modelB_counterfactual_pred = np.argmax(modelB_counterfactual_probits,
                                           axis=-1)

    if is_binary:
        validation = np.mean(modelA_pred != modelB_counterfactual_pred)
    else:
        # Copies are passed because convert_to_super_labels writes into its
        # argument.
        modelA_super_labels = convert_to_super_labels(modelA_pred.copy(),
                                                      affinity_set)
        modelB_counterfactual_super_labels = convert_to_super_labels(
            modelB_counterfactual_pred.copy(), affinity_set)

        validation = np.mean(
            modelA_super_labels != modelB_counterfactual_super_labels)

    # Confidence of model B's prediction on the counterfactuals whose label
    # differs from model A's.
    # NOTE(review): this mask compares the raw labels even in the
    # multi-class branch, where the rate above is based on super-labels.
    # Confirm whether that is intended.
    differs = modelA_pred != modelB_counterfactual_pred
    modelB_counterfactual_confidence = np.max(modelB_counterfactual_probits,
                                              axis=-1)[differs]

    if aggregation is not None:
        modelB_counterfactual_confidence = aggregation(
            modelB_counterfactual_confidence)

    invalidation_rate = 1.0 - validation
    if return_pred_B:
        return invalidation_rate, (modelB_counterfactual_confidence,
                                   modelB_counterfactual_pred)
    return invalidation_rate, modelB_counterfactual_confidence
