import configparser
import json

import matplotlib.pyplot as plt
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed
import numpy as np
import os

import matplotlib.patheffects as pe


# Per-series style tables. Entries are looked up by position (main() uses
# fixed indices into all three lists), so keep their order stable.
linestyles = ['--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-', '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted']
markers = ['o', '^', 's', 'D', 'v', 'x', '*', 'p', 'o', '^', 's', 'D', 'v', 'p', '*', 'x']

# Active colour palette. Earlier palettes that were immediately overwritten
# (and therefore never used) have been dropped.
colors = ['#797979', '#A9B98B', '#98B8DD', '#54936D', '#BD9273', '#CCA29F', '#F3975F']



def read_config_file(config_path):
    """Load a plotting configuration from an ``.ini`` or ``.json`` file.

    The parser is chosen from the file extension (case-insensitive), not from
    a substring anywhere in the path, so directory names cannot change the
    format that is used.

    Args:
        config_path: Path to the configuration file.

    Returns:
        A ``configparser.ConfigParser`` for ``.ini`` files, or the decoded
        JSON value for ``.json`` files.

    Raises:
        FileNotFoundError: If the file does not exist (for ``.ini`` files this
            is checked explicitly, because ``ConfigParser.read`` silently
            skips files it cannot open).
        NotImplementedError: If the extension is neither ``.ini`` nor ``.json``.
    """
    extension = os.path.splitext(config_path)[1].lower()
    if extension == ".ini":
        config = configparser.ConfigParser()
        # read() returns the list of files it actually parsed; an empty list
        # means the file was missing or unreadable.
        if not config.read(config_path, encoding="utf-8"):
            raise FileNotFoundError(
                "config file not found or unreadable: {}".format(config_path))
    elif extension == ".json":
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise NotImplementedError("No implement read")
    return config



def _load_length_acc(filename):
    """Load the ``"all_length_acc"`` entry from a pickled result record.

    If the file's base name does not contain ``"record"``, ``"record_"`` is
    prefixed to the base name (the directory part is left untouched).

    Args:
        filename: Path to a result file as listed in the config.

    Returns:
        The object stored under the ``"all_length_acc"`` key of the pickle.
    """
    directory, basename = os.path.split(filename)
    # Inspect only the base name so that a directory such as ".../records/"
    # does not suppress the "record_" prefix.
    if "record" not in basename:
        filename = os.path.join(directory, "record_" + basename)
    # NOTE: pickle.load can execute arbitrary code; only use this on result
    # files you produced yourself.
    with open(filename, "rb") as f:
        return pickle.load(f)["all_length_acc"]


def _compute_curve(length_acc, smooth_gamma, max_length=60 * 1024,
                   merge_gap=10, warmup_length=2000):
    """Turn a ``{length: [scores]}`` mapping into a smoothed curve.

    Lengths are processed in ascending order and processing stops at the first
    length above ``max_length``. A length within ``merge_gap`` tokens of the
    first length of the current group is merged into that group: its scores
    are pooled and the group's point is recomputed. Points below
    ``warmup_length`` are forced to mean 1 and variance 0. Each mean is
    exponentially smoothed against the previous point (seeded with 1):
    ``y = mean * (1 - smooth_gamma) + y_prev * smooth_gamma``.

    Args:
        length_acc: Mapping from numeric length to a list of numeric scores
            (presumably per-sample accuracies; NaNs are ignored).
        smooth_gamma: Weight given to the previous smoothed point.
        max_length: Largest length kept.
        merge_gap: Maximum distance to a group's first length for merging.
        warmup_length: Lengths below this are pinned to mean 1, variance 0.

    Returns:
        ``(xs, ys, variances)``: ``xs`` is a list of group start lengths, while
        ``ys`` and ``variances`` are numpy arrays of the same length.
    """
    xs = []
    ys = [1]  # smoothing seed; dropped before returning
    variances = []
    pooled = []
    for length, values in sorted(length_acc.items()):
        if length > max_length:
            break
        if xs and abs(length - xs[-1]) <= merge_gap:
            # Merge into the current group and replace its last point.
            pooled.extend(values)
            ys.pop()
            variances.pop()
        else:
            xs.append(length)
            # Copy so that extending never mutates the caller's data.
            pooled = list(values)

        if length < warmup_length:
            mean, var = 1.0, 0.0
        else:
            mean = np.nanmean(pooled)
            var = np.nanvar(pooled)

        ys.append(mean * (1 - smooth_gamma) + ys[-1] * smooth_gamma)
        variances.append(var)
    return xs, np.array(ys[1:]), np.array(variances)


def _style_axes(fig, ax, max_length=60 * 1024):
    """Apply the paper-figure styling: spines, grid, labels, ticks, margins.

    Draws a vertical reference line at x=2048 and places an x tick every 1024
    tokens up to ``max_length``. Only every 5th tick (in units of 1024) is
    labelled, e.g. "0k", "5k", ...
    """
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_color('lightgrey')
    # Darker stroke on the right/bottom spines for a raised look.
    for side in ("right", "bottom"):
        ax.spines[side].set_path_effects([pe.withStroke(linewidth=2, foreground='grey')])

    ax.set_facecolor('#f8f8f8')
    ax.grid(True, linewidth=0.5, alpha=0.5)
    ax.axvline(2048, color='gray', linewidth=1.0, zorder=1)

    ax.set_xlabel("Token Length", fontsize=14)
    ax.set_ylabel("Accuracy", fontsize=14)
    ax.set_ylim(-0.05, 1.05)

    ticks = np.arange(0, max_length + 1024, 1024)
    tick_labels = [
        str(int(t) // 1024) + "k" if (int(t) // 1024) % 5 == 0 else ""
        for t in ticks
    ]
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels, fontsize=12)
    fig.subplots_adjust(left=0.08, right=0.98, bottom=0.14, top=0.98)


def main(args=None):
    """Plot smoothed accuracy-vs-token-length curves described by a config.

    The config's ``General`` section supplies ``files`` (result pickles),
    ``labels`` (legend labels, paired with ``files`` positionally via zip),
    ``smooth_gamma`` (smoothing weight on the previous point) and
    ``model_name``. Each curve is drawn with a shaded band of +/- the pooled
    variance (variance, not standard deviation). Labels containing
    "enhanced" get a fixed legend label and style index 4, and all others use
    style index 6. The figure is saved to ``'<model_name>.png'`` and then
    shown.

    Args:
        args: Namespace with a ``config_file`` attribute (e.g. from argparse).

    Raises:
        FileNotFoundError: If ``args`` is None or names no config file.
    """
    if args is None or not args.config_file:
        raise FileNotFoundError("no config file")

    config = read_config_file(args.config_file)
    general = config['General']
    # NOTE(review): with an .ini config, "files"/"labels" come back as plain
    # strings rather than lists; only JSON configs yield lists here.
    files = general['files']
    labels = general['labels']
    # float() also accepts the string values produced by ConfigParser.
    smooth_gamma = float(general['smooth_gamma'])
    model_name = general['model_name']

    fig, ax1 = plt.subplots(figsize=(8, 5))

    enhanced_id = 4
    mesa_id = 6
    for _filename, _label in zip(files, labels):
        xs, ys, variances = _compute_curve(_load_length_acc(_filename), smooth_gamma)

        if "enhanced" in _label:
            style_id = enhanced_id
            label = "Mesa-Extrapolation (enhanced-prompt)"
        else:
            style_id = mesa_id
            label = _label

        ax1.plot(xs, ys, label=label, linewidth=2, marker=markers[style_id],
                 markersize=5, linestyle=linestyles[style_id], color=colors[style_id])
        ax1.fill_between(xs, ys - variances, ys + variances, alpha=0.1,
                         color=colors[style_id])

    _style_axes(fig, ax1)
    ax1.legend()
    fig.savefig('{}.png'.format(model_name))
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Resolve the default against BASE_DIR (the parent of this script's
    # directory) instead of the current working directory. This is the same
    # file as "../conf/..." when launched from the script's own folder, but
    # it also works when launched from anywhere else.
    parser.add_argument(
        "-c", "--config_file", type=str,
        default=os.path.join(BASE_DIR, "conf", "llama-3b-ablation-result51.json"),
    )
    # Another config used with this script:
    # conf/enhanced-llama-3b-passkey-mean-var-result31.json
    args = parser.parse_args()
    main(args)
