# import
import torch
import numpy as np
from scipy import stats

def create_epoch_metrics():
    """Return a new zero-initialised accumulator for per-epoch metrics.

    The dict holds log-likelihood sums, correlation statistics with their
    p-values, and a running sample ``count``. Every value starts at the
    integer ``0``, and a new dict is built on every call, so callers never
    share state.
    """
    metric_names = (
        'log_ll_ebm',
        'log_ll_e',
        'log_ll_s',
        'log_ll_err',
        'log_ll_err_var',
        'kendalltau',
        'kendalltau_pvalue',
        'kendalltau_self',
        'kendalltau_self_pvalue',
        'spearman',
        'spearman_pvalue',
        'spearman_self',
        'spearman_self_pvalue',
        'pearson',
        'pearson_pvalue',
        'pearson_self',
        'pearson_self_pvalue',
        'count',
    )
    # dict.fromkeys keeps the tuple's order as insertion order.
    return dict.fromkeys(metric_names, 0)

def update_epoch_metrics(epoch_metrics, cmp_results, count):
    """Add one comparison result, weighted by ``count``, into ``epoch_metrics``.

    ``epoch_metrics`` is mutated in place and nothing is returned. Each
    correlation entry of ``cmp_results`` must expose ``.statistic`` and
    ``.pvalue``. The log-likelihood entries are scalars and are stored under
    renamed keys. Presumably the sums are later divided by
    ``epoch_metrics['count']`` to get a count-weighted mean; confirm with the caller.
    """
    # (result key, statistic accumulator, p-value accumulator)
    correlation_slots = (
        ('kendalltau', 'kendalltau', 'kendalltau_pvalue'),
        ('kendalltau_self', 'kendalltau_self', 'kendalltau_self_pvalue'),
        ('spearman', 'spearman', 'spearman_pvalue'),
        ('spearman_self', 'spearman_self', 'spearman_self_pvalue'),
        ('pearson', 'pearson', 'pearson_pvalue'),
        ('pearson_self', 'pearson_self', 'pearson_self_pvalue'),
    )
    for result_key, stat_key, pvalue_key in correlation_slots:
        test_result = cmp_results[result_key]
        epoch_metrics[stat_key] += test_result.statistic * count
        epoch_metrics[pvalue_key] += test_result.pvalue * count

    # (accumulator key, result key) for the scalar log-likelihood summaries.
    scalar_slots = (
        ('log_ll_ebm', 'logp_ebm'),
        ('log_ll_e', 'logp'),
        ('log_ll_s', 'logp_2'),
        ('log_ll_err', 'logp_err'),
        ('log_ll_err_var', 'logp_err_var'),
    )
    for metric_key, result_key in scalar_slots:
        epoch_metrics[metric_key] += cmp_results[result_key] * count

    epoch_metrics['count'] += count

def preprocess_logp(logp, threshold=1.0):
    """Select a thinned subset of ``logp`` entries, in descending order.

    The values are walked from largest to smallest. The largest entry is
    always kept. Each later entry is kept only if it is strictly more than
    ``threshold`` below the most recently kept value, so values that sit
    close together collapse to one representative.

    Args:
        logp: tensor of log-probabilities; assumed 1-D (TODO confirm with
            callers, since indexing uses ``shape[0]`` while sorting uses the
            last dim).
        threshold: minimum drop below the last kept value required to keep
            the next value.

    Returns:
        LongTensor of indices into ``logp`` for the kept entries, ordered by
        descending ``logp``. It is empty when ``logp`` is empty.
    """
    if logp.shape[0] == 0:
        # The original indexed element 0 unconditionally and crashed here.
        return torch.empty(0, dtype=torch.long, device=logp.device)

    # detach(): converting to host values must not fail for tensors that
    # require grad.
    logp_sorted, indices_arm = torch.sort(logp.detach(), dim=-1, descending=True)
    # Compare plain Python floats in the loop. Comparing 0-d tensors would
    # force one device sync per element on GPU.
    sorted_values = logp_sorted.cpu().tolist()

    select_ordered_ind = [0]
    prev_logp = sorted_values[0]
    for i, cur_logp in enumerate(sorted_values[1:], start=1):
        if prev_logp - cur_logp > threshold:
            select_ordered_ind.append(i)
            prev_logp = cur_logp
    return indices_arm[select_ordered_ind]

def compare_logp(logp, logp_2, logp_ebm, log_z):
    """Compare model log-probabilities with EBM scores and with a second estimate.

    Args:
        logp: log-probabilities for a batch of samples. Assumed to be a 1-D
            tensor with one entry per sample (TODO confirm with callers).
        logp_2: a second set of log-probabilities for the same samples. It is
            compared against ``logp`` in the ``*_self`` metrics.
        logp_ebm: EBM scores for the same samples.
        log_z: offset added to ``logp - logp_ebm`` when measuring the error.
            Presumably a log-partition-function estimate (confirm).

    Returns:
        dict with
          * ``'logp_ebm'``, ``'logp'``, ``'logp_2'``: batch means, as floats;
          * ``'pearson'``, ``'spearman'``, ``'kendalltau'``: SciPy test
            results (``.statistic`` / ``.pvalue``) comparing ``logp_ebm``
            with ``logp``. Kendall's tau is computed on tie-free ranks;
          * ``'pearson_self'``, ``'spearman_self'``, ``'kendalltau_self'``:
            the same tests comparing ``logp`` with ``logp_2``;
          * ``'logp_err'``: mean absolute value of ``logp - logp_ebm + log_z``,
            as a float;
          * ``'logp_err_var'``: unbiased variance of that residual, as a float.
            It is NaN for a single sample.
    """
    def _ranks(x):
        # Ascending-order position of each element. argsort of the sort
        # permutation is itself a permutation, so the ranks never tie.
        _, order = torch.sort(x.detach(), dim=-1)
        return torch.argsort(order, dim=-1).cpu().numpy()

    # Convert once. detach() first, because .numpy() raises on tensors that
    # require grad.
    logp_np = logp.detach().cpu().numpy()
    logp_2_np = logp_2.detach().cpu().numpy()
    logp_ebm_np = logp_ebm.detach().cpu().numpy()

    ranking_ebm = _ranks(logp_ebm)
    ranking_arm = _ranks(logp)
    ranking_arm_2 = _ranks(logp_2)

    pearson = stats.pearsonr(logp_ebm_np, logp_np)
    pearson_self = stats.pearsonr(logp_np, logp_2_np)
    spearman = stats.spearmanr(logp_ebm_np, logp_np)
    spearman_self = stats.spearmanr(logp_np, logp_2_np)
    kendalltau = stats.kendalltau(ranking_ebm, ranking_arm)
    kendalltau_self = stats.kendalltau(ranking_arm, ranking_arm_2)

    residual = logp - logp_ebm + log_z
    # .item() makes these plain floats like the other scalar entries. A tensor
    # here would be carried into the epoch sums, possibly on GPU and possibly
    # holding an autograd graph.
    logp_err = residual.abs().mean().item()
    logp_err_var = residual.var().item()

    # put results into a dictionary
    results = {
        'logp_ebm': logp_ebm.mean().item(),
        'logp': logp.mean().item(),
        'logp_2': logp_2.mean().item(),
        'pearson': pearson,
        'pearson_self': pearson_self,
        'spearman': spearman,
        'spearman_self': spearman_self,
        'kendalltau': kendalltau,
        'kendalltau_self': kendalltau_self,
        'logp_err': logp_err,
        'logp_err_var': logp_err_var}

    return results