import configparser
import json

import matplotlib.pyplot as plt
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed, read_config_file
import numpy as np
import os

# Module-level smoothing factor for the BLEU exponential moving average.
# NOTE(review): main() assigns its own local `smooth_gamma` from the config
# file, so this global value is never read by main().
smooth_gamma = 0.0


# ROUGE variants summarised for every key of a result file, in output order.
_ROUGE_METRICS = ("rouge-1", "rouge-2", "rouge-l")


def _mean_f1(scores):
    """Return the mean ``'f'`` value of *scores* as a percentage rounded to 0.1.

    An empty list yields NaN instead of raising ``ZeroDivisionError``.
    """
    if not scores:
        return float("nan")
    return round(sum(item["f"] for item in scores) / len(scores) * 100, 1)


def _max_f1(scores):
    """Return the largest ``'f'`` value of *scores* (floored at 0.0) as a percentage rounded to 0.1."""
    return round(max([0.0, *(item["f"] for item in scores)]) * 100, 1)


def _format_rouges(entry, reduce_scores):
    """Reduce each ROUGE metric list of *entry* with *reduce_scores*; join as ``"r1/r2/rl"``."""
    return "/".join(str(reduce_scores(entry[metric])) for metric in _ROUGE_METRICS)


def _join_rounded(values, ndigits=3):
    """Join *values*, each rounded to *ndigits* places, into a ``", "``-separated string."""
    return ", ".join(str(round(value, ndigits)) for value in values)


def main(args=None):
    """Plot smoothed BLEU means with error bars for several result files and print summaries.

    The config file (loaded via ``read_config_file``) must provide
    ``config['General']`` with ``'files'`` (pickle paths), ``'labels'`` (one
    legend label per file) and ``'smooth_gamma'`` (EMA smoothing factor).

    Each pickle must contain ``"all_length_score"``: a dict whose keys are
    plotted on the x-axis. The axis is labelled "token length" and the ticks
    are shown in units of 1024. Each value is a dict holding:

    * ``"bleu-mean"``: a number;
    * ``"bleu"``: an array-like of scores;
    * ``"rouge-1"``, ``"rouge-2"`` and ``"rouge-l"``: lists of dicts with an
      ``"f"`` field.

    For every file the function draws an error-bar series and records the
    smoothed BLEU means, BLEU standard deviations, and mean/max ROUGE F1
    strings. It then prints all records, saves the figure to ``rouge.png``
    and shows it.

    Args:
        args: parsed command-line namespace; only ``args.config_file`` is read.

    Raises:
        FileNotFoundError: if no config file is given (including ``args=None``).
    """
    config_file = getattr(args, "config_file", None)
    if not config_file:
        raise FileNotFoundError("no config file")
    general = read_config_file(config_file)['General']
    files = general['files']
    labels = general['labels']
    smooth_gamma = general['smooth_gamma']

    width = 80  # x-offset between datasets so their error bars do not overlap
    keys = []  # sorted keys of the most recent dataset; reused for the x ticks

    bleus_record = []
    variance_values_record = []
    rouges_record = []
    max_rouges_record = []

    for index, (filename, label) in enumerate(zip(files, labels)):
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # result files that come from a trusted source.
        with open(filename, "rb") as f:
            data = pickle.load(f)["all_length_score"]

        keys = sorted(data.keys())
        smoothed = []
        std_values = []
        rouges = []
        max_rouges = []
        # Exponential moving average of the BLEU means. NOTE(review): seeding
        # with 0 pulls the first point toward 0 whenever smooth_gamma > 0.
        previous = 0
        for key in keys:
            entry = data[key]
            previous = entry["bleu-mean"] * (1 - smooth_gamma) + previous * smooth_gamma
            smoothed.append(previous)
            # Standard deviation (not variance) is what the error bars show.
            std_values.append(np.nanstd(entry["bleu"]))
            rouges.append(_format_rouges(entry, _mean_f1))
            max_rouges.append(_format_rouges(entry, _max_f1))

        # Plot mean with spread; shift each dataset horizontally by index * width.
        x = np.array(keys)
        mean_values = np.array(smoothed)
        std_values = np.array(std_values)
        plt.errorbar(x + index * width, mean_values, yerr=std_values, fmt='o', label=label)

        bleus_record.append("{}: bleu-mean-values: {}".format(label, _join_rounded(mean_values)))
        rouges_record.append("{}: rouge-value: {}".format(label, ", ".join(rouges)))
        max_rouges_record.append("{}: max-rouge-value: {}".format(label, ", ".join(max_rouges)))
        # The printed label says "variance" but the values are standard deviations.
        variance_values_record.append("{}: variance-values: {}".format(label, _join_rounded(std_values)))

    print("\n\n\n")
    print("\n".join(bleus_record))
    print("################################################\n")
    print("\n".join(variance_values_record))
    print("################################################\n")
    print("\n".join(rouges_record))
    print("################################################\n")
    print("\n".join(max_rouges_record))

    # Axis labels and title (the error bars are standard deviations).
    plt.xlabel("token length")
    plt.ylabel("mean with variance")
    plt.title("bleu (smooth:{})".format(smooth_gamma))
    # Ticks follow the last dataset's keys, shown in units of 1024 (e.g. "4k").
    # With no datasets this is an empty tick list instead of a NameError.
    tick_positions = np.array(keys)
    plt.xticks(tick_positions, [str(int(length / 1024)) + "k" for length in tick_positions])
    plt.legend()
    plt.savefig('rouge.png')
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Anchor the default on BASE_DIR (this script's parent directory) rather
    # than the current working directory. The old "../conf/..." path only
    # resolved when the script was launched from its own directory.
    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default=os.path.join(BASE_DIR, "conf", "llama2-7b-chat-rouge-govreport-result12.json"),
        help="config file listing the result pickles, their labels and smooth_gamma",
    )
    args = parser.parse_args()
    main(args)