############## PARAMETERS TO ADJUST ##############
# GZSL: when True, the evaluation further down keeps every label from the
# dataset's full label list instead of only the unseen ones.
GZSL = False
# DSET: which dataset to evaluate ('Eurlex', 'Amzn' or 'Wiki').
DSET = 'Wiki'
# NUM_LBLS: the number embedded in the prediction file names
# (results<NUM_LBLS>*). Any DSET other than 'Eurlex' / 'Amzn' falls back to
# the last value.
NUM_LBLS = {'Eurlex': 4271, 'Amzn': 13330}.get(DSET, 1285321)
# Alternative counts, swap in by hand (presumably the unseen-label split
# sizes -- TODO confirm):
# NUM_LBLS = {'Eurlex': 1057, 'Amzn': 6500}.get(DSET, 1285321)
##################################################

from tqdm import tqdm
import os
# Work from inside 'predictions': the relative result/qrels globs and the
# '../SemSup-LMLC/...' label-file paths used later in this script are all
# resolved against this directory.
os.chdir('predictions')
from glob import glob

# ---------------------------------------------------------------------------
# Evaluate every prediction file matching results<NUM_LBLS>* (most recently
# modified first) against its qrels file and print P@{1,3,5} and R@{5,10,20},
# with predictions and gold labels restricted to the unseen labels (or to all
# labels when GZSL is set).
# ---------------------------------------------------------------------------
import pickle  # hoisted out of the per-file loop

# DSET -> (unseen-labels file, all-labels file), relative to the working dir.
_LABEL_FILES = {
    'Amzn': (
        '../SemSup-LMLC/training/datasets/Amzn13K/unseen_labels_split6500_2.txt',
        '../SemSup-LMLC/training/datasets/Amzn13K/all_labels.txt',
    ),
    'Eurlex': (
        '../SemSup-LMLC/training/datasets/eurlex4.3k/unseen_labels_split1057.txt',
        '../SemSup-LMLC/training/datasets/eurlex4.3k/all_labels.txt',
    ),
    'Wiki': (
        '../SemSup-LMLC/training/datasets/Wiki1M/unseen_labels.txt',
        '../SemSup-LMLC/training/datasets/Wiki1M/all_labels.txt',
    ),
}


def _read_label_set(path):
    """Return the stripped lines of the text file at *path* as a set."""
    with open(path) as fh:
        return {line.strip() for line in fh}


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file.

    NOTE: unpickling can execute arbitrary code; only point this script at
    prediction / qrels files you trust.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


res_files = glob(f'results{NUM_LBLS}*')
res_files.sort(key=os.path.getmtime, reverse=True)

top_vals = [1, 3, 5, 10, 20]
# The label files do not depend on the result file, so they are read at most
# once -- lazily, so nothing is opened when there are no result files.
label_sets = None

for res_file in res_files:
    print(res_file)

    # The qrels file shares the result file's name with 'results' replaced by
    # 'qrels' and the last '_'-separated component dropped; the first glob
    # match is used.
    qrels_prefix = '_'.join(res_file.replace('results', 'qrels').split('_')[:-1])
    qrels_matches = glob(qrels_prefix + '*')
    if not qrels_matches:
        raise FileNotFoundError(
            f'No qrels file matching {qrels_prefix}* found for {res_file}'
        )
    qfile = qrels_matches[0]

    print('Loading Files')
    outp = _load_pickle(res_file)
    print(len(outp))
    print('Load Outp File')
    lbls = _load_pickle(qfile)
    print('Load All Files')

    if label_sets is None:
        if DSET not in _LABEL_FILES:
            raise ValueError(
                f'Unsupported DSET {DSET!r}; expected one of {sorted(_LABEL_FILES)}'
            )
        unseen_path, all_path = _LABEL_FILES[DSET]
        label_sets = (_read_label_set(unseen_path), _read_label_set(all_path))
    unseen_labels, all_labels = label_sets

    if GZSL:
        # Generalized setting: evaluate over every label, not just unseen ones.
        unseen_labels = all_labels

    # Restrict predictions and gold labels to the evaluated label set.
    # new_outp[k]: label -> score, ordered by descending score.
    # new_lbls[k]: set of gold labels of query k that are in the label set.
    new_lbls = dict()
    new_outp = dict()
    for k in tqdm(outp):
        kept = {
            label: score
            for label, score in outp[k].items()
            if label in unseen_labels
        }
        new_outp[k] = dict(sorted(kept.items(), key=lambda item: -item[1]))
        new_lbls[k] = unseen_labels.intersection(lbls[k])

    # Accumulate hit counts (for precision) and per-query recall at each
    # cutoff; queries with no gold label in the evaluated set are skipped.
    size = 0
    scores = {k: 0 for k in top_vals}
    recall_scores = {k: 0 for k in top_vals}
    comp_len_sum = 0
    for k in tqdm(new_outp):
        relevant = new_lbls[k]
        if not relevant:
            continue
        size += 1
        ranked = list(new_outp[k])  # labels in descending-score order
        for val in top_vals:
            to_add = len(relevant.intersection(ranked[:val]))
            scores[val] += to_add
            recall_scores[val] += to_add / len(relevant)
        comp_len_sum += len(relevant)
    print('Total Labels for recall', comp_len_sum)
    print(size)

    if size == 0:
        # No query has a gold label in the evaluated set, so the averages are
        # undefined; report it and move on instead of dividing by zero.
        print('No queries with relevant labels; skipping metrics')
        continue

    # P@k = hits in the top k, averaged over counted queries and k slots.
    # R@k = mean over counted queries of hits / number of gold labels.
    final_scores = dict()
    for k, v in scores.items():
        if k in (1, 3, 5):
            final_scores[f'P@{k}'] = v / (size * k)
    for k, v in recall_scores.items():
        if k in (5, 10, 20):
            final_scores[f'R@{k}'] = v / size
    print(' '.join(str(x) for x in final_scores.values()))