import configparser
import json
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import os, sys
# Append the parent of this script's directory to sys.path, presumably so the
# project-local `utils.utils` import below resolves when run as a script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed
import numpy as np
import os

# Per-method plot styles. These lists are consumed in parallel with ``labels``
# below; entries beyond len(labels) are currently unused.
linestyles = [
    '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-',
    '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted',
]
markers = [
    'o', '^', 's', 'D', 'v', 'p', '*', 'x',
    'o', '^', 's', 'D', 'v', 'p', '*', 'x',
]
colors = [
    'blue', 'green', 'purple', 'orange', 'brown', 'pink', 'red', 'gray',
    'blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink', 'gray',
]

# Method names with a dedicated style; "other" is the fallback style entry.
labels = [
    "Origin",
    "ReRoPE",
    "Leaky-ReRoPE",
    "Dynamic-NTK",
    "LM-Infinite",
    "Streaming-LLM",
    "Mesa-Extrapolation",
    "other",
]

# label -> keyword arguments (linestyle / marker / color) for Axes.plot.
label_setting = {
    name: {"linestyle": style, "marker": mark, "color": shade}
    for name, style, mark, shade in zip(labels, linestyles, markers, colors)
}

print(label_setting)




def read_config_file(config_path):
    """Load a plot configuration from the ``../conf/`` directory.

    If ``config_path`` does not already contain ``"../conf/"``, that prefix is
    prepended. The resulting path is resolved against the current working
    directory, not against this script's location.

    Args:
        config_path: File name or path of a ``.ini`` or ``.json`` config.

    Returns:
        A ``configparser.ConfigParser`` for ``.ini`` files, or the decoded JSON
        object for ``.json`` files.

    Raises:
        FileNotFoundError: if the config file does not exist.
        NotImplementedError: if the file is neither ``.ini`` nor ``.json``.
    """
    if "../conf/" not in config_path:
        config_path = "../conf/" + config_path

    # Dispatch on the actual extension; a substring test would route names
    # such as "a.ini.json" to the INI parser.
    if config_path.endswith(".ini"):
        config = configparser.ConfigParser()
        # ConfigParser.read silently skips missing files and returns the list
        # of files it parsed; fail loudly, like the JSON branch does.
        if not config.read(config_path, encoding="utf-8"):
            raise FileNotFoundError("Config file not found: {}".format(config_path))
    elif config_path.endswith(".json"):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise NotImplementedError(
            "Unsupported config format (expected .ini or .json): {}".format(config_path)
        )
    return config


def handle_subplot_data(files, labels, smooth_gamma, max_length=30 * 1024, merge_tolerance=10):
    """Turn pickled per-length accuracy records into smoothed curves.

    Each result file is a pickle whose ``"all_length_acc"`` entry maps a token
    length to a collection of per-sample scores. Lengths are processed in
    ascending order. A length within ``merge_tolerance`` tokens of the current
    bucket's first length is pooled into that bucket. Each bucket contributes
    one point: the exponentially smoothed nan-mean of its pooled scores, plus
    their (unsmoothed) nan-variance.

    Args:
        files: Result file paths. If a path does not contain ``"record"``,
            ``"record_"`` is prefixed to its last ``/``-separated component.
        labels: Curve labels, paired positionally with ``files``.
        smooth_gamma: Weight of the previous smoothed value in the moving
            average (0 disables smoothing); the average is seeded with 1.
        max_length: Lengths above this value are ignored (default 30k).
        merge_tolerance: Maximum token distance for pooling into a bucket.

    Returns:
        A list of ``(x, y, var_values, label)`` tuples. ``x`` is a list of
        bucket start lengths, ``y`` an array of smoothed means, and
        ``var_values`` an array of per-bucket variances.
    """
    results = []
    for _filename, _label in zip(files, labels):
        if "record" not in _filename:
            names = _filename.split("/")
            names[-1] = "record_" + names[-1]
            _filename = "/".join(names)

        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(_filename, "rb") as f:
            data = pickle.load(f)["all_length_acc"]

        x = []           # bucket start lengths
        y = []           # smoothed mean per bucket
        var_values = []  # variance per bucket
        pooled = []      # scores pooled into the current bucket

        # Both the early stop and the bucket merging assume ascending lengths.
        for length, value in sorted(data.items()):
            # Truncate the curve at max_length.
            if length > max_length:
                break

            if x and abs(length - x[-1]) <= merge_tolerance:
                # Close to the current bucket: pool the scores and recompute
                # that bucket's point in place.
                pooled.extend(value)
                y.pop()
                var_values.pop()
            else:
                x.append(length)
                # Copy so pooling never mutates the loaded data (and so
                # array-valued entries can still be extended).
                pooled = list(value)

            mean = np.nanmean(pooled)
            var = np.nanvar(pooled)

            # Pin the initial range to accuracy 1 / variance 0: below 2000
            # tokens for llama/vicuna/mpt result files, below 4000 for llama2.
            # This is a substring match on the path, so "llama2" also matches "llama".
            if "llama" in _filename or "vicuna" in _filename or "mpt" in _filename:
                if length < 2000:
                    mean = 1
                    var = 0
            if "llama2" in _filename:
                if length < 4000:
                    mean = 1
                    var = 0

            previous = y[-1] if y else 1  # smoothing is seeded with 1
            y.append(mean * (1 - smooth_gamma) + previous * smooth_gamma)
            var_values.append(var)

        results.append((x, np.array(y), np.array(var_values), _label))

    return results

def main(args=None):
    """Plot smoothed passkey accuracy curves for several configs in one figure.

    Each entry of ``args.files_list`` is a config file read with
    ``read_config_file``. Its ``General`` section provides ``files``,
    ``labels``, ``smooth_gamma`` and ``model_name``. Config ``i`` is drawn into
    subplot ``(i // 3, i % 3)``, and every curve gets a shaded +/- variance
    band. A single figure-level legend lists each label once. The figure is
    saved to ``Accuracy-all-6.png`` and then shown.

    Args:
        args: Parsed command-line namespace providing ``files_list``.

    Raises:
        ValueError: if ``args`` is ``None``.
    """
    if args is None:
        raise ValueError("main() needs parsed arguments providing 'files_list'")

    files_list = args.files_list
    all_labels = []  # legend labels, each listed once in first-seen order
    handles = []     # the Line2D drawn for each entry of all_labels

    # Three subplot columns. Keep the two-row layout, adding rows only when
    # more than six configs are given (previously an IndexError).
    ncols = 3
    nrows = max(2, (len(files_list) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 4 * nrows), sharey='all')

    for plot_idx, config_file in enumerate(files_list):
        general = read_config_file(config_file)['General']
        model_name = general['model_name']
        results = handle_subplot_data(
            files=general['files'],
            labels=general['labels'],
            smooth_gamma=general['smooth_gamma'],
        )

        idx, idy = divmod(plot_idx, ncols)
        ax = axes[idx, idy]

        # Largest x plotted in this subplot; it sets the tick range. It stays
        # None when the subplot received no data points.
        last_x = None
        for x, y, var_values, label in results:
            if "pythia-12b" in config_file and ("Mesa-Extrapolation" in label or "Origin" in label):
                # HACK: extend these two pythia-12b curves by one synthetic
                # point 984 tokens past their end, repeating the second-to-last
                # value and variance.
                x.append(x[-1] + 984)
                y = np.append(y, y[-2])
                var_values = np.append(var_values, var_values[-2])

            # Labels without a dedicated style share the "other" style.
            style = label_setting.get(label, label_setting["other"])
            plot_line, = ax.plot(
                x, y, label=label, linewidth=2,
                linestyle=style["linestyle"], marker=style["marker"],
                color=style["color"], markersize=6,
            )
            ax.fill_between(x, y - var_values, y + var_values, label=label, alpha=0.1,
                            color=plot_line.get_color())

            if label not in all_labels:
                all_labels.append(label)
                handles.append(plot_line)

            if x:
                last_x = x[-1] if last_x is None else max(last_x, x[-1])

        ax.set_title(model_name, fontsize=14)
        if last_x is not None:
            _set_length_ticks(ax, idx, idy, last_x)
        _style_spines(ax)

        # Vertical white dash-dot line at the maximum training length:
        # 4096 for subplot (0, 1), 2048 for every other subplot.
        train_len = 4096 if (idx, idy) == (0, 1) else 2048
        ax.axvline(train_len, color='white', linestyle='dashdot', linewidth=4.0, zorder=1, marker='o',
                   markersize=4)

    # Same light background for every subplot, including unused ones.
    for ax in np.ravel(axes):
        ax.set_facecolor('#f8f8f8')

    # Shared axis titles, drawn once for the whole figure.
    fig.text(0.01, 0.5, 'Accuracy', ha='center', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.03, 'Token Length', ha='center', va='center', fontsize=14)

    fig.subplots_adjust(left=0.05, right=0.98, bottom=0.14, top=0.92, wspace=0.1, hspace=0.3)

    # One legend for all subplots, centred below the grid.
    fig.legend(handles=handles, labels=all_labels, loc='lower center', bbox_to_anchor=(0.5, 0.05),
               ncol=max(1, len(all_labels)), fancybox=True, shadow=True)

    fig.savefig("Accuracy-all-6.png")
    plt.show()


def _tick_labels(ticks, show):
    """Label each tick "<n>k" (n = int(tick / 1024)) where ``show(n)`` is true, else ""."""
    return [str(int(t / 1024)) + "k" if show(int(t / 1024)) else "" for t in ticks]


def _set_length_ticks(ax, idx, idy, last_x):
    """Place x ticks every 1024 tokens on ``ax``, covering lengths up to ``last_x``.

    Which ticks get a text label depends on the subplot position:
    - (1, 1) labels odd multiples of 1k.
    - (1, 2) starts at 1k, adds one extra step, and labels 1k-6k.
    - Every other subplot labels n where n % 3 == 2.
    """
    max_len = (last_x // 1024 + 1) * 1024
    ticks = np.arange(0, max_len, 1024)
    if (idx, idy) == (1, 1):
        tick_labels = _tick_labels(ticks, lambda k: k % 2 == 1)
    elif (idx, idy) == (1, 2):
        max_len = (ticks[-1] // 1024 + 2) * 1024
        ticks = np.arange(1024, max_len, 1024)
        tick_labels = _tick_labels(ticks, lambda k: 0 < k < 7)
    else:
        tick_labels = _tick_labels(ticks, lambda k: k % 3 == 2)
    ax.set_xticks(np.array(ticks), tick_labels, fontsize=12)


def _style_spines(ax):
    """Give ``ax`` light-grey spines, with a grey stroke on the right and bottom for a raised look."""
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_color('lightgrey')
    for side in ('right', 'bottom'):
        ax.spines[side].set_path_effects([pe.withStroke(linewidth=2, foreground='grey')])


if __name__ == "__main__":
    # Default configs, drawn row-major into the subplot grid. A config's
    # position in this list decides which position-specific tick and marker
    # settings in main() apply to it.
    default_configs = [
        "passkey-mean-var-result1.json",
        "passkey-mean-var-result3.json",
        "passkey-mean-var-result2.json",
        "mpt-7b-passkey-mean-var-result13.json",
        "pythia-6.9b-passkey-mean-var-result16.json",
        "pythia-12b-passkey-mean-var-result19.json",
    ]
    cli = argparse.ArgumentParser()
    # NOTE(review): --config_file is parsed but never read by main().
    cli.add_argument("-c", "--config_file", type=str, default="../conf/mpt-7b-passkey-mean-var-result13.json")
    cli.add_argument('--files_list', nargs='+', type=str, default=default_configs)
    main(cli.parse_args())
