import pandas as pd
import argparse
import os
import heapq
import numpy as np
import pickle

def evaluate_sequence(data, top_k):
    """Return the top-k hit rate of a prediction table.

    A row counts as a hit when its ``gt`` value appears among the first
    ``top_k`` entries of its ``preds`` sequence (presumably ordered best-first
    by the producer of the CSV — confirm).

    Args:
        data: DataFrame with a ``gt`` column and a ``preds`` column whose
            cells are sliceable sequences (e.g. lists produced by ``eval_cols``).
        top_k: Number of leading predictions to consider per row.

    Returns:
        Fraction of rows that are hits, or NaN when ``data`` is empty
        (accuracy over zero rows is undefined).
    """
    total = len(data)
    if total == 0:
        return float('nan')
    # Iterate the columns positionally: ``data['gt'][i]`` would be a *label*
    # lookup and break on any index other than 0..n-1.
    hits = sum(gt in preds[:top_k] for gt, preds in zip(data['gt'], data['preds']))
    return hits / total

def eval_cols(data, cols):
    """Parse string-encoded Python values in the given columns, in place.

    Every cell of each column named in *cols* is passed through ``eval``,
    turning text such as ``"[3, 1, 2]"`` read back from a CSV into a real
    Python object. The parsed column overwrites the original on *data*, and
    *data* itself is returned for convenience.

    NOTE(review): ``eval`` executes arbitrary code contained in the CSV. If
    these files can ever come from an untrusted source, switch to
    ``ast.literal_eval`` (after confirming the cells hold only plain literals).
    """
    for column_name in cols:
        # Keep the lambda wrapper so ``eval`` resolves names in this module's
        # globals rather than in whatever frame pandas calls it from.
        parsed = data[column_name].apply(lambda cell: eval(cell))
        data[column_name] = parsed
    return data

def eval_topk(topk_, preds, labels):
    """Return top-k accuracy of ranked predictions against labels.

    Args:
        topk_: Number of leading entries of each prediction to consider.
        preds: Sequence of per-sample prediction sequences (each sliceable).
        labels: Sequence of ground-truth labels, aligned with ``preds``.

    Returns:
        Mean over samples of 1 if the label is in ``pred[:topk_]`` else 0,
        or NaN when there are no samples.

    Raises:
        ValueError: if ``preds`` and ``labels`` differ in length.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    # Without this check a longer ``labels`` would have its extra entries
    # silently ignored.
    if len(preds) != len(labels):
        raise ValueError(
            f'preds and labels must have the same length, got {len(preds)} and {len(labels)}'
        )
    if len(preds) == 0:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float('nan')
    hits = [1 if label in pred[:topk_] else 0 for pred, label in zip(preds, labels)]
    return np.mean(hits)

def evaluate_data_sequence(data_path, topks=(1, 3, 5, 10)):
    """Evaluate all ``.csv`` prediction files in a directory as one sequence set.

    Every CSV in *data_path* is read and the files are concatenated. The
    ``preds`` column is parsed from its string form with ``eval_cols``. The
    top-k hit rate from ``evaluate_sequence`` is then printed for each k.

    Args:
        data_path: Directory containing the prediction CSV files.
        topks: The k values to report.

    Returns:
        Dict mapping each k in ``topks`` to its accuracy.

    Raises:
        FileNotFoundError: if *data_path* contains no ``.csv`` files.
    """
    frames = []
    for file_ in os.listdir(data_path):
        print(file_)
        if file_.split('.')[-1] == 'csv':
            # Read from the directory we were given, not the CLI global ``args``.
            frames.append(pd.read_csv(os.path.join(data_path, file_)))
    if not frames:
        raise FileNotFoundError(f'no .csv files found in {data_path}')
    data = pd.concat(frames, ignore_index=True)
    data = eval_cols(data, ['preds'])

    results = {}
    for topk in topks:
        acc = evaluate_sequence(data, topk)
        print(f'-------top{topk} accuracy------')
        print('accuracy: {:.4f}'.format(acc))
        results[topk] = acc
    return results

def _evaluate_condition_dataset(data_path, categories, topks=(1, 3, 5, 10)):
    """Print per-category and average top-k accuracy for condition predictions.

    Reads every ``.csv`` in *data_path* and concatenates the files. For each
    category ``c``, the ``pred_c`` column is parsed with ``eval_cols`` and
    reduced to the indices of its ``max(topks)`` largest scores, highest
    first. Those indices are then compared with the ``true_c`` column via
    ``eval_topk``.

    Returns:
        Dict mapping each k to ``{category: accuracy, ..., 'average': mean}``.

    Raises:
        FileNotFoundError: if *data_path* contains no ``.csv`` files.
    """
    frames = []
    for file_ in os.listdir(data_path):
        print(file_)
        if file_.split('.')[-1] == 'csv':
            frames.append(pd.read_csv(os.path.join(data_path, file_)))
    if not frames:
        raise FileNotFoundError(f'no .csv files found in {data_path}')
    data = pd.concat(frames, ignore_index=True)

    max_k = max(topks)
    for category in categories:
        col = f'pred_{category}'
        data = eval_cols(data, [col])
        # Turn each score vector into its arg-top-k indices, best first.
        data[col] = data[col].apply(
            lambda scores: heapq.nlargest(max_k, range(len(scores)), scores.__getitem__)
        )

    results = {}
    for topk_ in topks:
        print(f'-------top{topk_} accuracy------')
        per_category = {}
        for category in categories:
            preds = data[f'pred_{category}'].values.tolist()
            labels = data[f'true_{category}'].values.tolist()
            acc = eval_topk(topk_, preds, labels)
            print('{}: {:.4f}'.format(category, acc))
            per_category[category] = acc
        average = np.mean(list(per_category.values()))
        print('average: {:.4f}'.format(average))
        per_category['average'] = average
        results[topk_] = per_category
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MM-RCR')
    parser.add_argument('--data', type=str, required=True, help='directory containing the prediction .csv files to evaluate')
    parser.add_argument('--dataset', type=str, required=True, help='dataset type')
    args = parser.parse_args()

    # Dispatch with if/elif: a bare ``return`` at module level is a SyntaxError.
    if '500mt' in args.dataset:
        evaluate_data_sequence(args.data)
    elif 'condition' in args.dataset:
        _evaluate_condition_dataset(args.data, ['c1', 's1', 's2', 'r1', 'r2'])
    else:
        parser.error(f'{args.dataset} not supported...')

                