import configparser
import json
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import os, sys
import torch

from matplotlib.ticker import MultipleLocator

# Put the parent of this script's directory on sys.path so the project-local
# `utils` package (see `from utils.utils import set_seed`) can be imported.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed
import numpy as np
import os

# Per-curve line styles / markers; entries are consumed in order when building
# `label_setting` below.
linestyles = ['--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-', '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted']
markers = ['o', '^', 's', 'D', 'v', 'x', '*', 'p', 'o', '^', 's', 'D', 'v', 'p', '*', 'x']

# Curve colour palette (also consumed in order).
colors = ['#797979', '#A9B98B', '#98B8DD', '#54936D', '#BD9273', '#CCA29F', '#F3975F']

# Curve labels that get a dedicated style. The "-1"/"-2" suffix matches the
# "<run label>-<layer index>" labels produced for plotting.
labels = [
    "Origin-1",
    "ReRoPE-1",
    "Dynamic-NTK-1",
    "Origin-2",
    "ReRoPE-2",
    "Dynamic-NTK-2"

]

# label -> {"linestyle", "marker", "color"}; zip() stops at the shortest list,
# so only the first len(labels) entries of each style list are used.
label_setting = {
    label: {"linestyle": linestyle, "marker": marker, "color": color}
    for label, linestyle, marker, color in zip(labels, linestyles, markers, colors)
}
# Fallback style for curves whose label is not listed above; the plotting code
# looks this key up for unknown labels (it was previously never defined).
label_setting["other"] = {"linestyle": "dotted", "marker": "*", "color": "#F3975F"}

print(label_setting)



# Hidden-state dimensions to plot, one subplot per entry (indexed by panel).
dimention_setting = [1, 6, 7, 9]


def read_config_file(config_path):
    """Load a configuration file from the ``../conf/`` directory.

    Args:
        config_path: File name or path. If it does not already contain
            ``"../conf/"``, that prefix is prepended, so the path is resolved
            relative to the current working directory.

    Returns:
        A ``configparser.ConfigParser`` for ``.ini`` files, or the value
        decoded by ``json.load`` for ``.json`` files.

    Raises:
        FileNotFoundError: if the file does not exist. This applies to both
            formats; ``ConfigParser.read`` would silently return an empty config.
        NotImplementedError: if the path ends in neither ``.ini`` nor ``.json``.
    """
    if "../conf/" not in config_path:
        config_path = "../conf/" + config_path
    # Dispatch on the actual extension: a substring test would treat
    # e.g. "model.initial.json" as an INI file.
    if config_path.endswith(".ini"):
        config = configparser.ConfigParser()
        with open(config_path, "r", encoding="utf-8") as f:
            config.read_file(f)
    elif config_path.endswith(".json"):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise NotImplementedError(
            "Unsupported config file type (expected .ini or .json): {}".format(config_path)
        )
    return config



def get_hidden_states(files, labels, map_location=None):
    """Load saved hidden-state dumps and pair each one with its label.

    Args:
        files: paths of files written with ``torch.save``.
        labels: one label per file, matched to ``files`` by position.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"`` lets
            dumps saved from GPU tensors load on a CPU-only machine. ``None``
            (the default) keeps ``torch.load``'s own behaviour.

    Returns:
        A list of ``(loaded_object, label)`` tuples, in input order.

    Raises:
        ValueError: if ``files`` and ``labels`` differ in length. Without this
            check, ``zip`` would silently drop the unmatched entries.
    """
    files, labels = list(files), list(labels)
    if len(files) != len(labels):
        raise ValueError(
            "got {} files but {} labels; they must match one-to-one".format(len(files), len(labels))
        )
    # NOTE(security): torch.load unpickles its input; only load trusted dumps.
    return [
        (torch.load(path, map_location=map_location), label)
        for path, label in zip(files, labels)
    ]

def get_value_from_dim(all_hidden_states, id, layers=(1,)):
    """Extract one configured hidden dimension from every run, per layer.

    Args:
        all_hidden_states: list of ``(hidden_states, label)`` pairs. For each
            layer ``l``, ``hidden_states[l][0]`` is indexed as ``[:, dim]``,
            so it is assumed to be 2-D (presumably (seq_len, hidden_dim)).
            TODO: confirm against the script that produced the dumps.
        id: index into the module-level ``dimention_setting`` list, selecting
            which hidden dimension to read.
        layers: layer indices to extract. The default ``(1,)`` reproduces the
            previous hard-coded behaviour.

    Returns:
        A list of ``(None, y, None, label)`` tuples, ordered by run and then
        by layer. ``y`` is the selected dimension moved to CPU, and ``label`` is
        ``"<run label>-<layer index>"``. The ``None`` slots are placeholders
        that keep the 4-tuple layout stable for callers.
    """
    dim = dimention_setting[id]
    return [
        (None, hidden_states[layer_index][0][:, dim].cpu(), None, "{}-{}".format(label, layer_index))
        for hidden_states, label in all_hidden_states
        for layer_index in layers
    ]


# Number of token positions dropped from each end of every curve before plotting.
_EDGE_TRIM = 10
# Style used when a curve label has neither its own entry nor an "other" entry
# in `label_setting`.
_FALLBACK_STYLE = {"linestyle": "dotted", "marker": "*", "color": "#F3975F"}


def _style_axis(ax, panel_index):
    """Decorate one subplot: x ticks, light frame, background, training-length marker."""
    # Ticks every 1k positions up to 16k; only 2k, 5k, 8k, 11k, 14k get a text label.
    ticks = np.arange(0, 16 * 1024, 1024)
    tick_labels = [str(int(t / 1024)) + "k" if int(t / 1024) % 3 == 2 else "" for t in ticks]
    ax.set_xticks(ticks, tick_labels, fontsize=12)

    # Light-grey frame; a darker stroke on the right/bottom spines gives a raised look.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_color("lightgrey")
    for side in ("right", "bottom"):
        ax.spines[side].set_path_effects([pe.withStroke(linewidth=2, foreground="grey")])

    ax.set_facecolor("#f8f8f8")

    # Vertical marker at the maximum training length: 4096 on the second panel,
    # 2048 on all others.
    train_len = 4096 if panel_index == 1 else 2048
    ax.axvline(train_len, color="white", linestyle="dashdot", linewidth=4.0, zorder=1,
               marker="o", markersize=4)


def main(args=None):
    """Plot the configured hidden-state dimensions of several runs side by side.

    Args:
        args: parsed namespace providing ``files`` (dumps passed to
            ``get_hidden_states``) and ``labels`` (one label per file).
            Although the default is ``None``, a namespace is required.

    Draws one subplot per entry of ``dimention_setting``, saves the figure as
    ``probe-validation-dimention.png`` in the working directory, then shows it.
    """
    all_hidden_states = get_hidden_states(args.files, args.labels)

    n_panels = len(dimention_setting)
    # One row of panels sharing the y axis; squeeze=False keeps `axes` 2-D even
    # for a single panel, so take the only row.
    fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(5 * n_panels, 5),
                             sharey='all', squeeze=False)
    axes = axes[0]

    # Legend handles keyed by label (first occurrence wins, insertion-ordered).
    legend_handles = {}

    for panel_index, ax in enumerate(axes):
        for _, y, _, label in get_value_from_dim(all_hidden_states, panel_index):
            style = label_setting.get(label) or label_setting.get("other", _FALLBACK_STYLE)
            # Plot against real token positions so the ticks and the
            # training-length marker stay aligned after trimming the edges.
            positions = np.arange(len(y))[_EDGE_TRIM:-_EDGE_TRIM]
            line, = ax.plot(positions, y[_EDGE_TRIM:-_EDGE_TRIM], label=label, linewidth=1,
                            marker=style["marker"], color=style["color"], markersize=1)
            legend_handles.setdefault(label, line)

        ax.set_title("dimension {}".format(dimention_setting[panel_index]), fontsize=14)
        _style_axis(ax, panel_index)

    # Shared axis labels, drawn once for the whole figure.
    fig.text(0.02, 0.5, 'Hidden State Value', ha='center', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.03, 'Token Length', ha='center', va='center', fontsize=14)

    plt.subplots_adjust(left=0.05, right=0.98, bottom=0.19, top=0.92, wspace=0.1, hspace=0.3)

    fig.legend(handles=list(legend_handles.values()), labels=list(legend_handles),
               loc='lower center', bbox_to_anchor=(0.5, 0.055),
               ncol=max(1, len(legend_handles)), fancybox=True, shadow=True)

    plt.savefig('probe-validation-dimention.png')
    plt.show()

if __name__ == "__main__":
    # Default inputs: one saved hidden-state dump per run; `--labels` are
    # paired with `--files` by position.
    default_files = [
        "../test/old_hello-16000_llama2-7b-chat_saved_all_hidden_states-5.6.pth",
        "../test/rerope_hello-16000_llama2-7b-chat_saved_all_hidden_states-5.6.pth",
        "../test/dynamic-ntk_hello-16000_llama2-7b-chat_saved_all_hidden_states-5.6.pth",
    ]
    default_labels = ["Origin", "ReRoPE", "Dynamic-NTK"]

    cli = argparse.ArgumentParser()
    for flag, default in (('--files', default_files), ('--labels', default_labels)):
        cli.add_argument(flag, nargs='+', type=str, default=default)
    main(cli.parse_args())
