"""Merge results from distributed evaluation."""

import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from lm_eval.tasks import get_task, ALL_TASKS
from utils import aggregate

from utils import PILE_WEIGHTS

def alpha_grid(num_points, step=0.01):
    """Return the x-axis values ``0, step, 2 * step, ...`` for an alpha sweep.

    The original script plotted against ``np.arange(0, 0.2, 0.01)``, i.e. 20
    points starting at 0 with a step of 0.01.  Deriving the grid from the
    number of points keeps that exact axis for 20-point sweeps while not
    failing inside ``plt.plot`` when a sweep has a different length.

    NOTE(review): assumes every sweep starts at alpha=0 with a fixed step of
    0.01 -- confirm against the code that produced the results.

    Args:
        num_points: Number of sweep points (columns of the merged stats).
        step: Spacing between consecutive alpha values.

    Returns:
        A 1-D numpy array of length ``num_points``.
    """
    return np.arange(num_points) * step


def load_task_stats(results_dir, task_name, ranks):
    """Collect the per-rank statistics saved for one task.

    Each rank is expected to have written ``<results_dir>/<task>_<rank>.pth``;
    the first element of the loaded object is concatenated across ranks.
    Loading is best-effort: a missing or unreadable file is reported and
    skipped so that the remaining ranks can still be merged.

    Args:
        results_dir: Directory containing the per-rank result files.
        task_name: Task whose result files should be read.
        ranks: Iterable of integer ranks to read (may be wrapped in ``tqdm``).

    Returns:
        A list with the concatenated entries of ``results[0]`` of every file
        that could be loaded; empty if none could.
    """
    merged = []
    for rank in ranks:
        results_file = '%s/%s_%d.pth' % (results_dir, task_name, rank)
        try:
            # NOTE: torch.load unpickles the file; only point this script at
            # result files produced by a trusted evaluation run.
            results = torch.load(results_file)
            merged += results[0]
        except FileNotFoundError:
            print('Not found: %s' % (results_file))
        except Exception as err:
            # Keep merging the other ranks, but do not mislabel a corrupt or
            # malformed file as "not found".
            print('Failed to load: %s (%r)' % (results_file, err))
    return merged


def main():
    """Merge distributed evaluation results, plot them and print summaries.

    For every selected task the per-rank statistics are merged, aggregated
    per sweep point with ``aggregate``, plotted against alpha and saved as a
    PDF in ``--results_dir``.  Tasks without any usable results are skipped.
    Finally the first and the minimum metric value per task are printed,
    followed by the ``PILE_WEIGHTS``-weighted sum over the tasks that have a
    weight.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, default='results/')
    parser.add_argument('--tasks', type=str, default='pile_all')
    parser.add_argument('--world_size', type=int, default=32)
    args = parser.parse_args()

    if args.tasks == 'pile_all':
        task_names = list(PILE_WEIGHTS.keys())
    elif args.tasks == 'super_glue':
        task_names = ['boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc']
    else:
        from utils import pattern_match
        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)

    metrics_before = {}
    metrics_after = {}

    pile_all = 0
    weighted_tasks = 0

    for task_name in task_names:
        task = get_task(task_name)(download=False)
        all_stats = np.array(
            load_task_stats(args.results_dir, task_name,
                            tqdm(range(args.world_size))))

        # Without at least one loaded row there is no second axis to sweep
        # over; skip the task instead of aborting the whole merge.
        if all_stats.ndim < 2 or all_stats.size == 0:
            print('No usable results for %s, skipping' % (task_name))
            continue

        # One aggregate per column, i.e. per sweep point.
        aggregate_stats = [aggregate(all_stats[:, j], task)
                           for j in range(all_stats.shape[1])]

        # NOTE(review): relies on aggregate() returning a dict whose third
        # key is the metric to report -- confirm against utils.aggregate.
        metric_name = list(aggregate_stats[0].keys())[2]
        plot_stats = [entry[metric_name] for entry in aggregate_stats]

        metrics_before[task_name] = plot_stats[0]
        # NOTE(review): taking the minimum presumes lower is better for the
        # selected metric -- confirm for accuracy-style tasks.
        metrics_after[task_name] = min(plot_stats)

        plt.figure()
        plt.plot(alpha_grid(len(plot_stats)), plot_stats)
        plt.ylabel(metric_name)
        plt.xlabel('alpha')
        plt.savefig('%s/alpha_%s_%s.pdf' % (args.results_dir, task_name, metric_name))
        # Release the figure; otherwise one figure per task stays open.
        plt.close()

        # Only tasks that have a Pile weight contribute to the weighted sum.
        if task_name in PILE_WEIGHTS:
            pile_all += np.array(plot_stats) * PILE_WEIGHTS[task_name]
            weighted_tasks += 1

    # Iterate over the tasks that actually produced metrics (insertion order
    # matches task_names, minus skipped tasks).
    for task_name in metrics_before:
        print(task_name, metrics_before[task_name], metrics_after[task_name])

    for task_name in metrics_after:
        print('%s %.2f' % (task_name, metrics_after[task_name]))

    if weighted_tasks:
        print(pile_all)


if __name__ == '__main__':
    main()
