import torch


# Token-fraction ratios reported for the Min-K% / Max-K% score families.
_TOPK_RATIOS = (0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
# The Min-K%++ / Max-K%++ families additionally report the full-sequence (k = 1.0) mean.
_TOPK_PLUS_RATIOS = _TOPK_RATIOS + (1.0,)


def _align_to_base(string_scores):
    """Return a copy of ``string_scores`` cast to float64, where each tensor keeps
    only its last N columns, N being the column count of the matching entry with
    ``'recall_'`` removed from its key (a no-op for the non-recall entries)."""
    aligned = {}
    for key, value in string_scores.items():
        base_len = string_scores[key.replace('recall_', '')].shape[1]
        # Explicit start index rather than ``value[:, -base_len:]``: ``-0:`` would
        # keep the whole tensor, while a zero-length base should yield an empty slice.
        start = max(value.shape[1] - base_len, 0)
        aligned[key] = value[:, start:].to(torch.float64)
    return aligned


def _add_topk_means(mia_scores, name, values, ratios, descending):
    """Store ``mia_scores[f'{name}_{ratio}']``: the mean of the ``ratio`` fraction of
    smallest (``descending=False``) or largest (``descending=True``) entries of
    ``values`` along the last dimension."""
    ordered = values.sort(dim=-1, descending=descending).values
    n_tokens = values.shape[1]
    for ratio in ratios:
        # Keep at least one token: plain int() truncation gives an empty slice, and
        # therefore a NaN mean, for short sequences (e.g. ratio 0.05 with < 20 tokens).
        # A zero-length sequence still produces NaN because the slice stays empty.
        count = max(1, int(ratio * n_tokens))
        mia_scores[f'{name}_{ratio}'] = ordered[:, :count].mean(-1)


def get_mia(string_scores):
    """Compute membership-inference scores from per-token statistics.

    Convention: members are expected to score high, non-members low.

    Args:
        string_scores: dict of tensors. Must contain ``token_log_probs``, ``hinge``,
            ``mu`` and ``sigma``, plus the same four keys prefixed with ``recall_``.
            Indexing with ``[:, ...]`` / ``shape[1]`` assumes 2-D
            ``(batch, tokens)`` tensors — TODO confirm against callers. Every
            ``recall_*`` tensor is truncated to its last N columns, where N is the
            column count of the matching non-recall tensor.

    Returns:
        dict mapping score name to a float64 tensor reduced over the last dimension.
        For each prefix in ``''`` and ``'recall_'``:
          * ``prob``: mean token log-prob; ``neg_hinge``: negated mean hinge.
          * ``mink_{k}`` / ``maxk_{k}``: mean of the lowest / highest fraction ``k``
            of token log-probs.
          * ``mink++_{k}`` / ``maxk++_{k}``: the same over
            ``(log_prob - mu) / sqrt(sigma)``; ``sigma`` is square-rooted here, so it
            is presumably a variance — confirm with the producer of ``string_scores``.
        Plus ``real_recall_{metric}``: ``recall_{metric} / ({metric} + 1e-6)``.
        Every top-k average uses at least one token, so only empty sequences
        yield NaN.
    """
    mia_scores = {}
    string_scores = _align_to_base(string_scores)

    for prefix in ['', 'recall_']:
        log_prob = string_scores[f'{prefix}token_log_probs']
        mia_scores[f'{prefix}prob'] = log_prob.mean(-1)
        mia_scores[f'{prefix}neg_hinge'] = -string_scores[f'{prefix}hinge'].mean(-1)

        _add_topk_means(mia_scores, f'{prefix}mink', log_prob, _TOPK_RATIOS, descending=False)
        _add_topk_means(mia_scores, f'{prefix}maxk', log_prob, _TOPK_RATIOS, descending=True)

        minkplus = (log_prob - string_scores[f'{prefix}mu']) / string_scores[f'{prefix}sigma'].sqrt()
        _add_topk_means(mia_scores, f'{prefix}mink++', minkplus, _TOPK_PLUS_RATIOS, descending=False)
        _add_topk_means(mia_scores, f'{prefix}maxk++', minkplus, _TOPK_PLUS_RATIOS, descending=True)

    # ReCall: ratio of each recall_-prefixed score to its plain counterpart.
    base_metrics = list(mia_scores.keys())
    for metric in base_metrics:
        if metric.startswith('recall_'):
            base_metric = metric[len('recall_'):]
            if base_metric in base_metrics:
                # NOTE(review): the +1e-6 offset is sign-agnostic; for negative
                # denominators (log-likelihood averages) it does not strictly guard
                # against division by zero. Kept unchanged to preserve published scores.
                mia_scores[f'real_recall_{base_metric}'] = mia_scores[metric] / (mia_scores[base_metric] + 1e-6)
    return mia_scores