"""Merge results from distributed evaluation."""


import torch
import argparse

from lm_eval.tasks import get_task, ALL_TASKS
from utils import aggregate


def _load_task_stats(results_dir, task_name, world_size):
    """Concatenate the stats every rank saved for ``task_name``.

    One file per rank is read from ``<results_dir>/<task_name>_<rank>.pth``
    for ``rank`` in ``range(world_size)``. Element 0 of each loaded object is
    appended to the returned list (presumably the per-example stats written
    by the evaluation workers -- confirm against the writer).

    The merge is best-effort: a rank whose file is missing or cannot be
    loaded is reported on stdout and skipped, so the result may cover only a
    subset of the ranks.
    """
    all_stats = []
    for rank in range(world_size):
        results_file = '%s/%s_%d.pth' % (results_dir, task_name, rank)
        try:
            # NOTE(review): torch.load unpickles the file, so --results_dir
            # should only ever point at output produced by trusted runs.
            results = torch.load(results_file)
            all_stats += results[0]
        except FileNotFoundError:
            print('Not found: %s' % (results_file))
        except Exception as err:
            # Keep merging the remaining ranks, but do not mislabel a corrupt
            # or malformed file as a missing one. KeyboardInterrupt and
            # SystemExit are deliberately not caught here.
            print('Failed to load: %s (%s: %s)'
                  % (results_file, type(err).__name__, err))
    return all_stats


def main():
    """Merge the per-rank result files of each task and print bits-per-byte."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, default='results/',
                        help='directory containing <task>_<rank>.pth files')
    parser.add_argument('--tasks', type=str, default='pile_all',
                        help="'pile_all', 'super_glue', or a comma-separated "
                             "list of patterns matched against ALL_TASKS")
    parser.add_argument('--world_size', type=int, default=32,
                        help='number of ranks whose result files are merged')
    args = parser.parse_args()

    if args.tasks == 'pile_all':
        from utils import PILE_WEIGHTS
        task_names = PILE_WEIGHTS.keys()
    elif args.tasks == 'super_glue':
        task_names = ['boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc']
    else:
        from utils import pattern_match
        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
    # The names are walked twice below; materialise them so a one-shot
    # iterator cannot leave the final report empty.
    task_names = list(task_names)

    metrics = {}
    for task_name in task_names:
        print(task_name)
        task = get_task(task_name)(download=False)
        all_stats = _load_task_stats(args.results_dir, task_name, args.world_size)
        metrics[task_name] = aggregate(all_stats, task)

    # Summary is printed only after every task has been aggregated.
    for task_name in task_names:
        print('%s %.2f' % (task_name, metrics[task_name]['bits_per_byte']))


if __name__ == '__main__':
    main()
        