import numpy as np
from scipy.stats import spearmanr, kendalltau, rankdata

def _is_better_record(record, rho, tau):
    """Return True if (rho, tau) should replace the stored ``record``.

    A record is replaced when it lacks either coefficient, or when BOTH new
    coefficients are strictly greater than the stored ones.
    """
    if 'rho_coeff' not in record or 'tau_coeff' not in record:
        return True
    return record['rho_coeff'] < rho and record['tau_coeff'] < tau


def get_corr_coeff(pred_imp_scores, videos, dataset, user_scores=None, save_scores=None):
    """Calculate Kendall's tau and Spearman's rho between predicted and true scores.

    Kendall's tau is computed on the ``rankdata`` ranks of both sequences.

    :param pred_imp_scores: Iterable of predicted importance-score sequences,
        paired element-wise with ``videos``.
    :param videos: Iterable of video identifiers (strings for TVSum).
    :param dataset: ``'SumMe'``, ``'TVSum'``, ``'QFVS'`` or ``'QFVS2'``. Any
        other value produces no coefficients, so the result is ``(nan, nan)``.
    :param user_scores: Ground truth, interpreted per dataset:

        - SumMe: an array (it must provide ``.tolist()``) that is averaged
          over axis 0; every prediction is compared with that average.
        - TVSum: an indexable collection. The trailing ``_<n>`` of each video
          name selects entry ``n - 1``. Every row of that entry is compared
          with the prediction, and the per-row coefficients are averaged.
        - QFVS / QFVS2: used directly as the ground-truth sequence.
    :param save_scores: Optional dict, updated in place for SumMe/TVSum. Maps
        each video to its best ``rho_coeff``/``tau_coeff`` together with the
        prediction and ground truth that produced them. An existing entry is
        replaced only when both coefficients strictly improve. A fresh dict
        is used when omitted.
    :return: ``(rho, tau)``, the mean Spearman and mean Kendall coefficient
        over all predictions.
    """
    if save_scores is None:
        # A shared mutable default would leak "best" records across calls.
        save_scores = {}
    rho_coeff, tau_coeff = [], []
    if dataset == 'SumMe':
        # The averaged ground truth is the same for every video: compute once.
        true = np.mean(user_scores, axis=0)
        true_ranks = rankdata(true)
        for pred_imp_score, video in zip(pred_imp_scores, videos):
            tmp_rho_coeff = spearmanr(pred_imp_score, true)[0]
            tmp_tau_coeff = kendalltau(rankdata(pred_imp_score), true_ranks)[0]
            rho_coeff.append(tmp_rho_coeff)
            tau_coeff.append(tmp_tau_coeff)

            selection_rho_coeff = np.mean(tmp_rho_coeff).item()
            selection_tau_coeff = np.mean(tmp_tau_coeff).item()
            record = save_scores.setdefault(video, {})
            if _is_better_record(record, selection_rho_coeff, selection_tau_coeff):
                save_scores[video] = {
                    'rho_coeff': selection_rho_coeff,
                    'tau_coeff': selection_tau_coeff,
                    'pred_imp_score': np.squeeze(pred_imp_score).tolist(),
                    'user_scores': user_scores.tolist(),
                }
    elif dataset == 'TVSum':
        for pred_imp_score, video in zip(pred_imp_scores, videos):
            pred_imp_score = np.squeeze(pred_imp_score).tolist()
            pred_ranks = rankdata(pred_imp_score)
            # Trailing "_<n>" of the video name is a 1-based index into user_scores.
            curr_user_score = user_scores[int(video.split("_")[-1]) - 1]
            tmp_rho_coeff = [spearmanr(pred_imp_score, annotation)[0]
                             for annotation in curr_user_score]
            tmp_tau_coeff = [kendalltau(pred_ranks, rankdata(annotation))[0]
                             for annotation in curr_user_score]

            selection_rho_coeff = np.mean(tmp_rho_coeff).item()
            selection_tau_coeff = np.mean(tmp_tau_coeff).item()
            rho_coeff.append(selection_rho_coeff)
            tau_coeff.append(selection_tau_coeff)

            record = save_scores.setdefault(video, {})
            if _is_better_record(record, selection_rho_coeff, selection_tau_coeff):
                save_scores[video] = {
                    'rho_coeff': selection_rho_coeff,
                    'tau_coeff': selection_tau_coeff,
                    'pred_imp_score': pred_imp_score,
                    'user_scores': [e.tolist() for e in curr_user_score],
                }
    elif dataset == 'QFVS' or dataset == 'QFVS2':
        true = user_scores
        for pred_imp_score, _video in zip(pred_imp_scores, videos):
            rho_coeff.append(spearmanr(pred_imp_score, true)[0])
            tau_coeff.append(kendalltau(rankdata(pred_imp_score), rankdata(true))[0])
    rho_coeff = np.array(rho_coeff).mean()
    tau_coeff = np.array(tau_coeff).mean()

    return rho_coeff, tau_coeff

def f1_score(pred: np.ndarray, test: np.ndarray) -> float:
    """Compute F1-score on binary classification task.

    Both inputs are cast to ``bool`` first, so any array-like of 0/1 values
    (e.g. a plain list) is accepted. The score is 0.0 when the two label
    vectors share no positive entry.

    :param pred: Predicted binary label. Sized [N].
    :param test: Ground truth binary label. Sized [N].
    :return: F1-score value.
    """
    pred = np.asarray(pred, dtype=bool)
    test = np.asarray(test, dtype=bool)
    # Checked after the cast so that ``.shape`` also exists for plain sequences.
    assert pred.shape == test.shape, f'shape mismatch: {pred.shape} vs {test.shape}'
    overlap = (pred & test).sum()
    if overlap == 0:
        # Also covers empty / all-negative inputs, avoiding 0/0 below.
        return 0.0
    precision = overlap / pred.sum()
    recall = overlap / test.sum()
    return float(2 * precision * recall / (precision + recall))


def get_summ_f1score(pred_summ: np.ndarray,
                     test_summ: np.ndarray,
                     eval_metric: str = 'avg'
                     ) -> float:
    """Score a predicted keyshot summary against every user's summary.

    The prediction is first aligned to the ground-truth length: extra frames
    are dropped, and missing frames are appended as ``False``. An F1-score is
    then computed against each user summary, and the per-user scores are
    reduced with the chosen metric.

    :param pred_summ: Predicted binary label of N frames. Sized [N].
    :param test_summ: Ground truth binary labels of U users. Sized [U, N].
    :param eval_metric: ``'avg'`` (mean over users) or ``'max'`` (best user).
    :return: F1-score value.
    :raises ValueError: If ``eval_metric`` is not one of the supported names.
    """
    prediction = np.asarray(pred_summ, dtype=bool)
    user_summaries = np.asarray(test_summ, dtype=bool)
    _, n_frames = user_summaries.shape

    excess = prediction.size - n_frames
    if excess > 0:
        prediction = prediction[:n_frames]
    elif excess < 0:
        prediction = np.pad(prediction, (0, -excess))

    per_user = [f1_score(user_summary, prediction)
                for user_summary in user_summaries]

    for metric_name, reducer in (('avg', np.mean), ('max', np.max)):
        if eval_metric == metric_name:
            return float(reducer(per_user))
    raise ValueError(f'Invalid eval metric {eval_metric}')

def get_summ_diversity(pred_summ: np.ndarray,
                       features: np.ndarray
                       ) -> float:
    """Evaluate diversity of the generated summary.

    Let ``f_1 .. f_k`` be the rows of ``features`` selected by ``pred_summ``.
    This returns the mean inner product over all ordered pairs of distinct
    selected frames::

        sum_{i != j} <f_i, f_j> / (k * (k - 1))

    For L2-normalized features, this is the mean pairwise cosine similarity.
    Smaller values therefore indicate a more varied selection.

    :param pred_summ: Predicted down-sampled summary, a binary mask over the
        N frames. Sized [N].
    :param features: Normalized down-sampled video features. Sized [N, F].
    :return: Mean pairwise inner product; 0.0 when fewer than two frames are
        selected.
    """
    assert len(pred_summ) == len(features)
    mask = np.asarray(pred_summ, dtype=bool)
    # Index first so only selected rows are copied; float64 keeps the
    # subtraction below from losing precision with float32 inputs.
    pos_features = np.asarray(features)[mask].astype(np.float64)

    n_pos = len(pos_features)
    if n_pos < 2:
        return 0.0

    # sum_{i != j} f_i . f_j == |sum_i f_i|^2 - sum_i |f_i|^2, which replaces
    # the O(k^2 * F) pairwise loop with O(k * F) work.
    feat_sum = pos_features.sum(axis=0)
    pair_total = np.sum(feat_sum * feat_sum) - np.sum(pos_features * pos_features)
    return float(pair_total / (n_pos * (n_pos - 1)))