"""Merge results from distributed training."""

import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from lm_eval.tasks import get_task, ALL_TASKS
from utils import aggregate


def median_length(sequences):
    """Return the median of ``len(s)`` over *sequences* as an int.

    The median is truncated with ``int()``, so for an even number of
    sequences it may be a length that no sequence actually has.

    Returns ``None`` when *sequences* is empty, because the median of an
    empty list is NaN and cannot be converted to an int.
    """
    lengths = [len(seq) for seq in sequences]
    if not lengths:
        return None
    return int(np.median(lengths))


def keep_length(sequences, length):
    """Return a list of the entries of *sequences* whose ``len`` equals *length*."""
    return [seq for seq in sequences if len(seq) == length]


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', type=str, default='results/')
    parser.add_argument('--tasks', type=str, default='pile_all')
    parser.add_argument('--world_size', type=int, default=32)
    args = parser.parse_args()

    # Resolve the --tasks argument into a concrete list of task names.
    if args.tasks == 'pile_all':
        from utils import PILE_WEIGHTS
        task_names = list(PILE_WEIGHTS)
    elif args.tasks == 'super_glue':
        task_names = ['boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc']
    else:
        from utils import pattern_match
        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)

    # Aggregated metrics at the first and last recorded step, keyed by task.
    # Only tasks for which usable results were found get an entry.
    metrics_before = {}
    metrics_after = {}

    for task_name in task_names:
        all_stats = []
        all_losses = []
        # Collected from every result file but not used further in this script.
        training_costs = []
        retrieval_costs = []

        task = get_task(task_name)(download=False)
        for rank in tqdm(range(args.world_size)):
            results_file = '%s/%s_%d.pth' % (args.results_dir, task_name, rank)
            # NOTE: torch.load unpickles its input; only point --results_dir
            # at result files you produced yourself.
            try:
                results = torch.load(results_file)
                # Read all four entries before touching the accumulators so a
                # malformed file cannot leave them partially extended.
                rank_stats = list(results[0])
                rank_losses = list(results[1])
                rank_training_costs = list(results[2])
                rank_retrieval_costs = list(results[3])
            except FileNotFoundError:
                print('Not found: %s' % (results_file))
                continue
            except Exception as err:
                # Stay best-effort for unreadable/malformed files, but say why
                # instead of reporting them as missing.
                print('Could not read %s: %s' % (results_file, err))
                continue

            print('Found: %s' % (results_file))
            all_stats += rank_stats
            all_losses += rank_losses
            training_costs += rank_training_costs
            retrieval_costs += rank_retrieval_costs

        median = median_length(all_stats)
        if median is None:
            print('No results found for %s' % (task_name))
            continue

        # Keep only entries of the median length so they stack into a
        # rectangular array. Losses are filtered by the same length as stats.
        all_stats = keep_length(all_stats, median)
        all_losses = keep_length(all_losses, median)
        if not all_stats or median == 0:
            # Happens when the truncated median matches no entry (e.g. lengths
            # 2 and 4 give 3) or when the entries are empty.
            print('No usable results of length %d for %s' % (median, task_name))
            continue

        # Aggregate column by column: column j holds the j-th entry of every
        # kept stats sequence.
        all_stats = np.array(all_stats)
        aggregate_stats = [
            aggregate(all_stats[:, j], task)
            for j in range(all_stats.shape[1])
        ]

        metrics_before[task_name] = aggregate_stats[0]
        metrics_after[task_name] = aggregate_stats[-1]

        # The plotted metric is whichever key comes third in the aggregate
        # dict -- presumably the task's headline metric; verify against
        # utils.aggregate.
        metric_name = list(aggregate_stats[0].keys())[2]
        plot_stats = [entry[metric_name] for entry in aggregate_stats]

        plt.figure()
        plt.plot(plot_stats)
        plt.ylabel(metric_name)
        plt.xlabel('training steps')
        plt.savefig('%s/%s_%s.pdf' % (args.results_dir, task_name, metric_name))
        # Close each figure once saved; otherwise every figure stays in memory
        # for the rest of the run.
        plt.close()

        plt.figure()
        all_losses = np.array(all_losses)
        plt.plot(np.mean(all_losses, axis=0))
        plt.ylabel('training loss')
        plt.savefig('%s/%s_%s.pdf' % (args.results_dir, task_name, 'train'))
        plt.close()

    # separate bar plot for each task that produced metrics
    for task_name in metrics_before:
        plt.figure()
        plt.title('Before vs after %s' % task_name)
        metric_name = list(metrics_before[task_name].keys())[2]
        before = metrics_before[task_name][metric_name]
        after = metrics_after[task_name][metric_name]
        plt.bar(['before', 'after'], [before, after], label=task_name)
        plt.ylabel(metric_name)
        plt.savefig('%s/before-after-%s.pdf' % (args.results_dir, task_name))
        plt.close()
