import sys
import io
# Re-wrap stdout with an explicit UTF-8 encoding so non-ASCII output (e.g. the
# \u2191 / \u2193 arrows printed below) can be written regardless of the
# console's default encoding.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

import numpy as np


def bold(string):
    """Return *string* wrapped in the ANSI "bold on" and "reset" sequences."""
    start, reset = '\033[1m', '\033[0m'
    return ''.join((start, string, reset))

def underline(string):
    """Return *string* wrapped in the ANSI "underline on" and "reset" sequences."""
    start, reset = '\033[4m', '\033[0m'
    return ''.join((start, string, reset))


# all: {'loss': 0.5110480383345128, 'f1': 0.7938357255170856, 'precision': 0.7936523699814242, 'recall': 0.7956469458464581}
# kbest: {'loss': 0.5232750598540521, 'f1': 0.7987335656651726, 'precision': 0.7986393293550187, 'recall': 0.8008241817577294}
# ours:  {'loss': 0.5096874166470566, 'f1': 0.7997539011108926, 'precision': 0.7990909918348529, 'recall': 0.8028970435673325}

# Feature ids ranked by probability (highest first) from the previous call to
# forest_compare; None until the first call.
_last_id_sort = None


def forest_compare(prob, selected_n=27):
    """Print a ranked view of per-feature probabilities and score the top picks.

    Features are ranked by ``prob`` in descending order. The sort is stable,
    so tied features keep ascending id order. One entry per feature is printed
    on a single line:

    * the probability formatted to three decimals;
    * the number of places it moved since the previous call's ranking, shown
      as a red-background up arrow or a yellow-background down arrow (no
      arrow on the first call, when unchanged, or for ids absent from the
      previous ranking);
    * bold if the id is in the hard-coded k-best list, underlined if it is in
      the hard-coded lasso list.

    A ``|`` separates the first ``selected_n`` entries from the rest.

    Args:
        prob: Sized, iterable sequence of per-feature scores; the position of
            each score is its feature id.
        selected_n: Number of top-ranked features treated as "selected".

    Returns:
        dict with ``'in_k_best'`` and ``'in_lasso'``: how many of the top
        ``selected_n`` feature ids also appear in the k-best / lasso id lists.
    """
    global _last_id_sort
    # 50 max_epoch # {'loss': 0.6452664123115505, 'f1': 0.7015734125230068, 'precision': 0.7051620352040737, 'recall': 0.7036517740534176}
    # 500 max_epoch # {'loss': 0.5137106735344706, 'f1': 0.776988913496114, 'precision': 0.7803794652210887, 'recall': 0.7774691071387122}
    best_k_ids = {int(i) for i in '0  2  3  5  6  8  9 10 11 12 13 15 16 17 19 23 25 26 30 35 36 42 43 45 51 52 53'.split()}
    #  10/10 {'loss': 0.509678203498245, 'f1': 0.779998614667319, 'precision': 0.7847522136409036, 'recall': 0.780007783707176}
    lasso_k_ids = {10, 51, 7, 5, 0, 4}

    values = list(prob)
    # Rank exactly once, so the printed order, the '|' cut-off and the
    # returned counts all agree on how ties are broken.
    ids_sort = sorted(range(len(values)), key=lambda idx: values[idx], reverse=True)

    # id -> rank index from the previous call. On the first call, compare
    # against the current ranking so that no movement arrows are shown.
    previous = ids_sort if _last_id_sort is None else _last_id_sort
    last_rank = {id_: rank for rank, id_ in enumerate(previous)}

    result = []
    for i, id_ in enumerate(ids_sort):
        text = '{:.3f}'.format(float(values[id_]))

        # Positive = moved up (towards the top). Ids missing from the previous
        # ranking (e.g. prob changed length) default to no movement.
        diff_rank = last_rank.get(id_, i) - i
        if diff_rank > 0:
            text += '\033[1;41m\u2191\033[0m' + str(diff_rank)
        elif diff_rank < 0:
            text += '\033[1;43m\u2193\033[0m' + str(-diff_rank)
        if id_ in best_k_ids:
            text = bold(text)
        if id_ in lasso_k_ids:
            text = underline(text)
        if i == selected_n:
            result.append('|')
        result.append(text)

    _last_id_sort = ids_sort
    print(' '.join(result))

    selected_ids = set(ids_sort[:selected_n])
    return {
        'in_k_best': len(best_k_ids & selected_ids),
        'in_lasso': len(lasso_k_ids & selected_ids),
    }


