import configparser
import json

import matplotlib.pyplot as plt
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed
import numpy as np
import os


def read_config_file(config_path):
    """Load a configuration file in INI or JSON format.

    The format is selected from the file extension (case-insensitive), not
    from a substring of the whole path, so directory names such as
    ``.initial/`` cannot cause a JSON file to be parsed as INI.

    Args:
        config_path: Path to a ``.ini`` or ``.json`` file.

    Returns:
        A ``configparser.ConfigParser`` for ``.ini`` files (all values are
        strings), or the decoded JSON object for ``.json`` files.

    Raises:
        FileNotFoundError: If the file does not exist or cannot be read.
        NotImplementedError: If the extension is neither ``.ini`` nor ``.json``.
    """
    extension = os.path.splitext(config_path)[1].lower()
    if extension == ".ini":
        config = configparser.ConfigParser()
        # ConfigParser.read silently skips files it cannot open and returns the
        # list of files it did read; an empty list would otherwise yield an
        # empty config that only fails later with a confusing KeyError.
        if not config.read(config_path, encoding="utf-8"):
            raise FileNotFoundError(
                "Config file not found or unreadable: {}".format(config_path)
            )
    elif extension == ".json":
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise NotImplementedError(
            "No implement read for config file {!r} "
            "(expected a .ini or .json file)".format(config_path)
        )
    return config



def _load_length_stats(path):
    """Return the ``"length_mean_var"`` mapping stored in the pickle at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code; only point the
    # config at result files you produced yourself.
    with open(path, "rb") as f:
        return pickle.load(f)["length_mean_var"]


def _smoothed_series(data, smooth_gamma):
    """Convert a ``{length: {"mean": ..., "var": ...}}`` mapping to plot data.

    The means are smoothed with an exponential moving average
    ``s_i = mean_i * (1 - smooth_gamma) + s_{i-1} * smooth_gamma``, seeded
    with ``s_{-1} = 1``. Entries are visited in the mapping's iteration order.

    Returns:
        ``(x, y, var_values)``: the lengths (list), the smoothed means and
        the raw ``"var"`` values (both NumPy arrays).
    """
    # NOTE(review): entries are not sorted; this assumes the mapping was
    # built in increasing-length order — confirm against the producer.
    lengths = []
    smoothed = []
    var_values = []
    previous = 1  # EMA seed (biases the first points towards 1)
    for length, value in data.items():
        previous = value["mean"] * (1 - smooth_gamma) + previous * smooth_gamma
        lengths.append(length)
        smoothed.append(previous)
        var_values.append(value["var"])
    return lengths, np.array(smoothed), np.array(var_values)


def main(args=None):
    """Plot smoothed passkey-retrieval accuracy against token length.

    Reads the ``General`` section of the config named by
    ``args.config_file`` (keys ``files``, ``labels``, ``smooth_gamma``).
    For each result file it draws a smoothed curve with a shaded band, then
    shows the figure.

    Args:
        args: Namespace with a ``config_file`` attribute (path to a .json or
            .ini config). ``None`` is treated as "no config file".

    Raises:
        FileNotFoundError: If no config file is given.
        ValueError: If ``files`` and ``labels`` have different lengths.
    """
    config_file = getattr(args, "config_file", None)
    if not config_file:
        raise FileNotFoundError("no config file")

    general = read_config_file(config_file)['General']
    # NOTE(review): with an .ini config, 'files' and 'labels' are plain
    # strings rather than lists — this path presumably expects a .json config.
    files = general['files']
    labels = general['labels']
    # float() also accepts the string values an .ini config yields.
    smooth_gamma = float(general['smooth_gamma'])

    # zip() would silently drop curves if the two lists disagree.
    if len(files) != len(labels):
        raise ValueError(
            "config 'files' ({}) and 'labels' ({}) must have the same length".format(
                len(files), len(labels)
            )
        )

    for filename, label in zip(files, labels):
        x, y, var_values = _smoothed_series(_load_length_stats(filename), smooth_gamma)
        plt.plot(x, y, label=label, linewidth=1.5)
        # NOTE(review): the band half-width is the raw "var" value; if that is
        # a variance rather than a standard deviation, np.sqrt(var) may be
        # intended — confirm against the producer of these results.
        plt.fill_between(x, y - var_values, y + var_values, alpha=0.1)

    # Axis labels
    plt.xlabel("token length")
    plt.ylabel("accuracy")
    # Title
    plt.title("passkey retrieval task (smooth: {})".format(smooth_gamma))

    plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)

    plt.legend()
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot smoothed passkey-retrieval accuracy curves."
    )
    # Resolve the default against the project root (BASE_DIR, the parent of
    # this script's directory) instead of the current working directory, so
    # the script works no matter where it is launched from.
    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default=os.path.join(BASE_DIR, "conf", "passkey-mean-var-result1.json"),
        help="Path to a .json or .ini config with a 'General' section.",
    )
    args = parser.parse_args()
    main(args)
