import configparser
import json

import matplotlib.pyplot as plt
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed, read_config_file
import numpy as np
import os
import matplotlib.patheffects as pe


# Style pools; entries are consumed in order, one per method in ``labels``.
linestyles = ['--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-', '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted']
markers = ['o', '^', 's', 'D', 'v', 'x', '*', 'p', 'o', '^', 's', 'D', 'v', 'p', '*', 'x']
# Keep ``colors`` at least as long as ``labels``: ``zip`` below stops at the
# shortest input, so a missing color silently drops the remaining labels
# (previously "other" was lost this way). The last color is for "other".
colors = ['#797979', '#A9B98B', '#98B8DD', '#54936D', '#BD9273', '#CCA29F', '#F3975F', '#8E8EC7']


# Method names that have a predefined plot style.
labels = [
    "Origin",
    "ReRoPE",
    "Leaky-ReRoPE",
    "Dynamic-NTK",
    "LM-Infinite",
    "Streaming-LLM",
    "Mesa-Extrapolation",
    "other"
]

# Map each method name -> {"linestyle", "marker", "color"}, pairing the
# i-th label with the i-th entry of each style pool.
label_setting = {}
for label, linestyle, marker, color in zip(labels, linestyles, markers, colors):
    label_setting[label] = {
        "linestyle": linestyle,
        "marker": marker,
        "color": color
    }

# NOTE(review): debug output emitted at import time; consider removing.
print(label_setting)

def _load_smoothed_curve(filename, smooth_gamma):
    """Load per-token NLL means from a result pickle and smooth them with an EMA.

    The pickle is read as a mapping whose ``"nll_stats_token"`` entry is
    itself a mapping; each of its values must provide a ``"mean"`` key
    (structure inferred from the accesses below -- confirm against the
    script that writes these files).

    The EMA is seeded with 0, so the first ~``1 / (1 - smooth_gamma)``
    values are pulled toward 0. This matches the original plotting
    behaviour and is kept deliberately.

    Args:
        filename: path to the result pickle.
        smooth_gamma: EMA decay; 0 means no smoothing.

    Returns:
        np.ndarray with one smoothed value per entry, in iteration order.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point the config at result files you produced yourself.
    with open(filename, "rb") as f:
        data = pickle.load(f)["nll_stats_token"]

    smoothed = []
    prev = 0
    for stats in data.values():
        prev = stats["mean"] * (1 - smooth_gamma) + prev * smooth_gamma
        smoothed.append(prev)
    return np.array(smoothed)


def _style_axes(ax):
    """Apply the paper-figure styling: axis labels/limits, ticks, grid, spines, background."""
    ax.set_xlabel("Token Length", fontsize=14)
    ax.set_ylabel("NLL", fontsize=14)
    ax.set_ylim(bottom=0, top=15)

    # A tick every 1k tokens up to 15k, but only 2k, 5k, 8k, 11k, 14k are
    # labelled to keep the axis uncluttered.
    ticks = np.arange(0, 16 * 1024, 1024)
    ax.set_xticks(ticks)
    ax.set_xticklabels(
        [f"{t // 1024}k" if (t // 1024) % 3 == 2 else "" for t in ticks],
        fontsize=12,
    )

    # Light horizontal grid lines only.
    ax.yaxis.grid(True, linewidth=0.5, alpha=0.5)

    ax.figure.subplots_adjust(left=0.09, right=0.97, bottom=0.12, top=0.93)

    # Soften the frame: light-grey spines, with a thicker stroke on the
    # right/bottom edges for a slight raised look.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_color('lightgrey')
    for side in ("right", "bottom"):
        ax.spines[side].set_path_effects([pe.withStroke(linewidth=3, foreground='lightgrey')])
    ax.set_facecolor('#f8f8f8')

    # Vertical marker at 2048 tokens (the author's note: maximum training length).
    ax.axvline(2048, color='white', linewidth=4.0, linestyle='dashdot', zorder=1, marker='o',
               markersize=4)


def main(args=None):
    """Plot smoothed per-token NLL curves for several runs and save the figure.

    The config (loaded with ``read_config_file``) must have a ``General``
    section with ``files`` (result pickles), ``labels`` (legend names,
    paired with ``files`` by position), ``smooth_gamma`` (EMA decay) and
    ``model_name`` (used in the output file name). Runs whose label has no
    entry in ``label_setting`` are skipped with a warning.

    The figure is saved to ``ppl_<model_name>.png`` in the current working
    directory and then shown.

    Args:
        args: namespace with a ``config_file`` attribute.

    Raises:
        FileNotFoundError: if ``args`` is None or no config file is given.
    """
    if args is None or not args.config_file:
        raise FileNotFoundError("no config file")

    config = read_config_file(args.config_file)
    general = config['General']
    files = general['files']
    run_labels = general['labels']
    smooth_gamma = general['smooth_gamma']
    model_name = general['model_name']

    fig, ax = plt.subplots()

    for filename, run_label in zip(files, run_labels):
        style = label_setting.get(run_label)
        if style is None:
            print(f"warning: no plot style for label {run_label!r}; skipping {filename}")
            continue
        y = _load_smoothed_curve(filename, smooth_gamma)
        # Use this run's own marker; the previous code looked up the stale
        # module-level ``label`` and gave every curve the same marker.
        ax.plot(y, label=run_label, linewidth=1,
                marker=style["marker"], color=style["color"], markersize=1)

    _style_axes(ax)

    ax.legend()
    fig.savefig(f'ppl_{model_name}.png')
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot smoothed per-token NLL curves from result pickles.")
    # Anchor the default config on the project root (BASE_DIR, the parent of
    # this script's directory) instead of the current working directory, so
    # the script works regardless of where it is launched from.
    parser.add_argument(
        "-c", "--config_file", type=str,
        default=os.path.join(BASE_DIR, "conf", "mpt-7b-ppl-pile-result17.json"),
        help="path to the plotting config file",
    )
    args = parser.parse_args()
    main(args)
