# coding=utf-8
from __future__ import absolute_import, unicode_literals

import re
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import pandas as pd


# Per-step training-log line, e.g. "... Step 120 ... perplexity=35.2 ...".
_PERPLEXITY_RE = re.compile(r"perplexity=(?P<perplexity>[\d.]+)")
_STEP_RE = re.compile(r"Step (?P<step>\d+)")
# Final summary on the "train done" line, e.g. "... train done ... perplexity:30.1".
_FINAL_PERPLEXITY_RE = re.compile(r"perplexity:(?P<perplexity>[\d.]+)")
# One experiment spec item, e.g. "DDP,B=64,S=1024".
_EXP_RE = re.compile(r"(?P<method>\w+),B=(?P<batch>\d+),S=(?P<seq_length>\d+)")


def _parse_methods(spec):
    """Parse a ';'-separated experiment spec into a list of method dicts.

    Each non-empty item must contain ``<method>,B=<batch>,S=<seq_length>``,
    e.g. ``"DDP,B=64,S=1024;MQSP,B=64,S=2048"``. Empty items (such as the one
    produced by a trailing ';') are skipped.

    Returns:
        A list of ``{"method": str, "batch": int, "seq_length": int}`` dicts,
        in the order given in ``spec``.

    Raises:
        ValueError: if a non-empty item does not match the expected format,
            or if ``spec`` contains no experiment at all.
    """
    methods = []
    for method_str in spec.split(";"):
        method_str = method_str.strip()
        if not method_str:
            continue
        match = _EXP_RE.search(method_str)
        if match is None:
            raise ValueError("invalid method define:%s" % method_str)
        methods.append(
            {
                "method": match.group("method"),
                "batch": int(match.group("batch")),
                "seq_length": int(match.group("seq_length")),
            }
        )
    if not methods:
        raise ValueError("invalid method define:%s" % spec)
    return methods


def _exp_name(exp):
    """Return the display label of an experiment dict, e.g. 'DDP B=64 S=1024'."""
    return "%s B=%d S=%d" % (exp["method"], exp["batch"], exp["seq_length"])


def _parse_log(lines, methods):
    """Extract per-step perplexity records from training-log lines.

    Lines are attributed to ``methods`` in order. Every line belongs to the
    current experiment until a line containing "train done" is seen. That line
    records the experiment's final perplexity and advances to the next
    experiment. Parsing stops after the last experiment's "train done" line.

    Returns:
        ``(data, exp_names, final_perplexities)``.

        - ``data`` is a list of ``{"exp", "tokens", "perplexity"}`` dicts.
        - ``exp_names`` lists the experiments that produced at least one
          record, in first-seen order.
        - ``final_perplexities`` maps an experiment name to its final
          perplexity.
    """
    data = []
    exp_names = []
    seen = set()
    final_perplexities = {}
    if not methods:
        return data, exp_names, final_perplexities

    exp_idx = 0
    exp = methods[exp_idx]
    exp_name = _exp_name(exp)  # computed once per experiment, not once per line
    for line in lines:
        if "train done" in line:
            final_match = _FINAL_PERPLEXITY_RE.search(line)
            if final_match is None:
                # Still advance so later experiments are attributed correctly.
                print("not found final perplexity: %s" % line.rstrip())
            else:
                final_perplexities[exp_name] = float(final_match.group("perplexity"))
            exp_idx += 1
            if exp_idx >= len(methods):
                break
            exp = methods[exp_idx]
            exp_name = _exp_name(exp)
            continue

        perplexity_match = _PERPLEXITY_RE.search(line)
        if perplexity_match is None:
            continue
        step_match = _STEP_RE.search(line)
        if step_match is None:
            # Without a step the token count is unknown; skip instead of crashing.
            print("not found step: %s" % line.rstrip())
            continue

        step = int(step_match.group("step"))
        # Tokens consumed so far = steps * sequences per step * tokens per sequence.
        tokens = step * exp["batch"] * exp["seq_length"]
        if exp_name not in seen:
            seen.add(exp_name)
            exp_names.append(exp_name)
        data.append(
            {
                "exp": exp_name,
                "tokens": tokens,
                "perplexity": float(perplexity_match.group("perplexity")),
            }
        )
    return data, exp_names, final_perplexities


def _plot(data, exp_names):
    """Plot perplexity (log scale) against millions of tokens, one curve per experiment."""
    df = pd.DataFrame(data)
    _, ax = plt.subplots()
    for exp_name in exp_names:
        exp_df = df[df["exp"] == exp_name]
        ax.plot(exp_df["tokens"] / 1e6, exp_df["perplexity"], label=exp_name)
    ax.set_xlabel("tokens(M)")
    ax.set_ylabel("perplexity")
    ax.set_yscale("log")
    if exp_names:
        # Avoid matplotlib's "no artists with labels" warning on empty logs.
        ax.legend()
    plt.show()


def main():
    """Plot perplexity curves for the experiments recorded in a training log.

    Command line:
        logfile     Path of the training log to parse.
        --method    Required ';'-separated list of experiments, in the order
                    they appear in the log, e.g.
                    'DDP,B=64,S=1024;MQSP,B=64,S=1024'.

    Prints the final perplexity of each experiment, taken from its
    "train done" line, and then shows the plot.

    Raises:
        ValueError: if ``--method`` contains a malformed experiment spec.
    """
    parser = ArgumentParser()
    parser.add_argument("logfile")
    parser.add_argument(
        "--method", required=True, help="split exp by ';', format like:'DDP,B=64,S=1024;MQSP,B=64,S=1024'"
    )
    args = parser.parse_args()
    methods = _parse_methods(args.method)
    with open(args.logfile) as fin:
        data, exp_names, final_perplexities = _parse_log(fin, methods)
    print(final_perplexities)
    _plot(data, exp_names)


if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not when imported.
    main()
