import json
import argparse

import numpy as np
import matplotlib.pyplot as plt

from itertools import groupby


def load_rows(path, device):
    """Read benchmark results for one device from a JSON-lines file.

    Each non-blank line of *path* is parsed as one JSON object. Only rows
    whose ``device`` field equals *device* are kept. For kept rows the
    ``model`` field is rewritten to ``'<last dotted component>-r=<ncomponent>'``
    so it can be used directly as a plot label.

    Args:
        path: Path to the results file (one JSON object per line).
        device: Device name to keep, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        List of row dicts in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a kept row lacks ``model`` or ``ncomponent``.
    """
    rows = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank / trailing-newline-only lines instead of crashing.
            if not line.strip():
                continue
            row = json.loads(line)
            # Compare the parsed value rather than substring-matching the raw
            # line, which broke on any JSON formatting without ': ' spacing.
            if row.get('device') != device:
                continue
            short_name = row['model'].split('.')[-1]
            row['model'] = '%s-r=%d' % (short_name, row['ncomponent'])
            rows.append(row)
    return rows


def group_series(rows):
    """Group rows into one plottable series per model label.

    Unlike a bare ``itertools.groupby``, rows of the same model need not be
    adjacent: they are merged into a single series. Series are returned in
    order of each model's first appearance. Within a series, points are sorted
    by ``ntoken`` so the connecting line does not zig-zag.

    Args:
        rows: Row dicts with ``model``, ``ntoken`` and ``tokens_per_second``.

    Returns:
        List of ``(model, n_tokens, tokens_per_second)`` tuples, where the last
        two are numpy arrays of equal length.
    """
    grouped = {}  # dicts preserve insertion order -> first-appearance order
    for row in rows:
        grouped.setdefault(row['model'], []).append(row)

    series = []
    for model, stats in grouped.items():
        stats = sorted(stats, key=lambda s: s['ntoken'])
        n_tokens = np.array([s['ntoken'] for s in stats])
        tps = np.array([s['tokens_per_second'] for s in stats])
        series.append((model, n_tokens, tps))
    return series


def main(argv=None):
    """Parse CLI arguments and show a tokens-per-second vs. ntoken plot.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # required=True gives a clear usage error instead of open(None) failing.
    parser.add_argument('--results', required=True,
                        help='Path to throughput.txt file (list of json).')
    parser.add_argument('--device', choices=('cuda', 'cpu'),
                        default='cuda',
                        help='Device to plot throughput for.')
    args = parser.parse_args(argv)

    rows = load_rows(args.results, args.device)
    series = group_series(rows)

    _, ax = plt.subplots(figsize=(10, 6))
    for model, n_tokens, tps in series:
        ax.plot(n_tokens, tps, '-o', label=model)
    ax.set_ylim([0, None])
    ax.set_ylabel('Tokens Per Second', fontsize=24)
    ax.set_xlabel('Num Tokens Generated by Head', fontsize=24)
    ax.set_title('Throughput for models on %s' % args.device, fontsize=30)
    # Avoid matplotlib's "No artists with labels" warning on empty input.
    if series:
        ax.legend(fontsize=20, loc='lower right')
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()
