import numpy as np
from collections import OrderedDict
import torch.nn.functional as F 
import warnings
warnings.filterwarnings("ignore")
from sklearn import metrics

def _safe_f1(precision, recall):
    """Return the harmonic mean of *precision* and *recall*.

    Returns 0.0 when both inputs are zero instead of dividing by zero.
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return (2 * precision * recall) / denominator


def _precision_recall_f1(targets, binary_predictions, average):
    """Return ``(precision, recall, f1)`` for the given sklearn averaging mode."""
    precision = metrics.precision_score(targets, binary_predictions, average=average)
    recall = metrics.recall_score(targets, binary_predictions, average=average)
    return precision, recall, _safe_f1(precision, recall)


def compute_metrics(all_predictions, all_targets, verbose=True):
    """Compute mAP and thresholded precision/recall/F1 for multi-label output.

    A sigmoid is applied to ``all_predictions`` before scoring, so the caller
    is expected to pass raw logits.

    Args:
        all_predictions: torch tensor of logits. Must be 2-D, since the top-3
            selection indexes along the second axis; presumably laid out as
            (n_samples, n_labels) -- confirm against the caller.
        all_targets: torch tensor of ground-truth labels with the same layout
            (presumably a 0/1 indicator matrix -- confirm against the caller).
        verbose: when true, print every computed metric as a percentage.

    Returns:
        dict with keys ``'mAP'``, ``'OF1'`` (micro-averaged F1) and ``'CF1'``
        (macro-averaged F1), all computed at a 0.5 threshold. The top-3
        metrics are only printed, not returned.
    """
    optimal_threshold = 0.5
    top_k = 3

    # Detach and move to CPU first so tensors that require grad or live on an
    # accelerator can still be converted and handed to scikit-learn.
    scores = all_predictions.sigmoid().detach().cpu().numpy()
    targets = all_targets.detach().cpu().numpy()

    # `average` is keyword-only in current scikit-learn releases.
    meanAP = metrics.average_precision_score(
        targets, scores, average='macro', pos_label=1)

    # Top-3 variant: keep a label only if it is among the k highest scores of
    # its row AND reaches the threshold. k is capped at the number of labels
    # so inputs with fewer than three labels do not raise IndexError.
    k = min(top_k, scores.shape[1])
    kth_largest = np.sort(scores, axis=1)[:, -k].reshape(-1, 1)
    keep_top3 = (scores >= kth_largest) & (scores >= optimal_threshold)
    all_predictions_top3 = keep_top3.astype(scores.dtype)

    CP_top3, CR_top3, CF1_top3 = _precision_recall_f1(
        targets, all_predictions_top3, 'macro')
    OP_top3, OR_top3, OF1_top3 = _precision_recall_f1(
        targets, all_predictions_top3, 'micro')

    # Plain threshold variant. NOTE: this comparison is strict (>), while the
    # top-3 variant above is inclusive (>=); the two differ only for scores
    # exactly equal to the threshold. Kept as in the original implementation.
    all_predictions_thresh = (scores > optimal_threshold).astype(scores.dtype)

    CP, CR, CF1 = _precision_recall_f1(targets, all_predictions_thresh, 'macro')
    OP, OR, OF1 = _precision_recall_f1(targets, all_predictions_thresh, 'micro')

    if verbose:
        report = [
            [('mAP:   ', meanAP)],
            [
                ('CP:    ', CP),
                ('CR:    ', CR),
                ('CF1:   ', CF1),
                ('OP:    ', OP),
                ('OR:    ', OR),
                ('OF1:   ', OF1),
            ],
            [
                ('CP_t3: ', CP_top3),
                ('CR_t3: ', CR_top3),
                ('CF1_t3:', CF1_top3),
                ('OP_t3: ', OP_top3),
                ('OR_t3: ', OR_top3),
                ('OF1_t3:', OF1_top3),
            ],
        ]
        for group in report:
            print('----')
            for label, value in group:
                print('{}{:0.1f}'.format(label, value * 100))

    metrics_dict = {}
    metrics_dict['mAP'] = meanAP
    metrics_dict['OF1'] = OF1
    metrics_dict['CF1'] = CF1

    return metrics_dict