import glob
import numpy as np
import json
import pandas as pd
import argparse

def _str2bool(value):
    """Convert a command-line token such as "true" / "False" / "0" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the option could never be
    switched off explicitly.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False because bool("") was False before.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,)
    )


# Command-line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--json-path", type=str, required=True,
                    help="input JSON Lines file with one evaluated sample per line")
parser.add_argument("--output-path", type=str, required=True,
                    help="path of the Excel report to write")
# nargs="?" + const=True lets a bare "--bon-res" act as a flag while the
# explicit "--bon-res True" / "--bon-res False" forms keep working.
parser.add_argument("--bon-res", type=_str2bool, nargs="?", const=True, default=False,
                    help="also report best-of-n (per-prompt maximum) scores")

args = parser.parse_args()


# Tune names (词牌名) that the report groups under "宋词_常用词牌"; any other
# tune name of a 宋词 sample falls under "宋词_非常用词牌".
useful_cipai = [
    "蝶恋花", "苏幕遮", "浣溪沙", "沁园春",
    "卜算子", "水调歌头", "江城子", "喜迁莺",
    "念奴娇", "青玉案", "渔家傲", "忆江南",
    "满江红", "永遇乐", "一剪梅", "定风波",
    "虞美人", "鹧鸪天", "贺新郎", "西江月",
    "采桑子", "钗头凤", "长相思", "捣练子",
]

def save_json(data, file_name):
    """Write ``data`` to ``file_name`` as a single UTF-8 encoded JSON document.

    Non-ASCII characters are stored as-is instead of being escaped. The data
    is serialized before the file is opened, so a serialization error does
    not touch the target file.
    """
    serialized = json.dumps(data, ensure_ascii=False)
    with open(file_name, mode='w', encoding='utf-8') as out_file:
        out_file.write(serialized)

def read_json_lines(file_name):
    """Load a UTF-8 JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a stray empty line at the end of the
    file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_name: path of the ``.jsonl`` file to read.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def save_jsonl(data, file_name):
    """Write each item of ``data`` to ``file_name`` as one JSON document per line.

    The file is written as UTF-8 explicitly: the records are dumped with
    ``ensure_ascii=False``, so relying on the platform default encoding could
    fail or produce a file that ``read_json_lines`` cannot read back.

    Args:
        data: iterable of JSON-serializable objects.
        file_name: path of the ``.jsonl`` file to (over)write.
    """
    with open(file_name, "w", encoding="utf-8") as file:
        for d in data:
            file.write(json.dumps(d, ensure_ascii=False) + "\n")

def read_json(file_name):
    """Load and return the single JSON document stored in ``file_name``.

    The file is decoded as UTF-8 explicitly, matching what ``save_json``
    writes, instead of depending on the platform default encoding.

    Raises:
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    with open(file_name, 'r', encoding='utf-8') as file:
        return json.load(file)

def get_usable_rate(reward_dic):
    """Accumulate per-metric score totals for one report row.

    Args:
        reward_dic: mapping whose ``"all"`` entry is a list of reward dicts
            (one per sample) and whose every other entry maps a prompt to the
            list of reward dicts of that prompt's samples. A reward dict maps
            a metric name to a score or ``None``, e.g.::

                {"标题": 0.0, "行数": None, "每行字数": 1.0, "场景描写": 1.0}

    Returns:
        A ``(res, bon_res)`` tuple of dicts keyed by metric name, each value
        being ``[count, total]``:

        * ``res`` counts every non-null score found under ``"all"`` and sums
          those scores.
        * ``bon_res`` (best-of-n) counts, per metric, the prompts having at
          least one non-null score and sums each such prompt's best score.
          The best score is floored at 0.0.

        Metrics outside the tracked set are ignored; a metric that was never
        scored keeps ``[0, 0]``.
    """
    metric_names = (
        "标题", "行数", "每行字数", "总字数", "押韵", "平仄",
        "风格类型", "创作对象", "场景描写", "表达主题", "包含元素",
    )
    res = {name: [0, 0] for name in metric_names}
    bon_res = {name: [0, 0] for name in metric_names}

    for group, samples in reward_dic.items():
        if group == "all":
            for sample in samples:
                for name, score in sample.items():
                    if name in res and score is not None:
                        res[name][0] += 1
                        res[name][1] += float(score)
            continue

        # Best-of-n for one prompt. Every sample is inspected for every metric,
        # so an empty sample list, or a sample where a metric is null/missing,
        # is skipped rather than raising, and the result does not depend on
        # which sample happens to come first.
        for name in metric_names:
            scores = [float(s[name]) for s in samples if s.get(name) is not None]
            if not scores:
                continue
            bon_res[name][0] += 1
            bon_res[name][1] += max(0.0, max(scores))
    return res, bon_res

if __name__ == '__main__':
    # Script entry point: read the evaluated samples, group their rewards by
    # domain / genre (体裁) / sub-domain, and write the mean score of every
    # metric per group to an Excel sheet.
    print("-----------------------")
    print("Poem evaluate.")
    input_path = args.json_path
    output_path = args.output_path

    data = read_json_lines(input_path)

    # res_dic[level][type_name] is a dict with an "all" list holding every
    # reward dict of that type, plus one list per prompt text holding the
    # reward dicts of that prompt's samples (consumed by get_usable_rate).
    # NOTE: a prompt whose text is literally "all" would collide with the
    # aggregate key.
    res_dic = {
        "domain": {},
        "ticai": {},
        "sub_domain": {},
    }
    # type_dic[level] lists the type names of that level in first-seen order;
    # it fixes the row order of the report.
    type_dic = {
        "domain": [],
        "ticai": [],
        "sub_domain": [],
    }

    score_name_list = ["行数", "每行字数", "总字数", "押韵", "平仄", "标题", "风格类型", "创作对象", "场景描写", "表达主题", "包含元素"]

    raw_header = ["level", "类型/词牌"] + score_name_list
    if args.bon_res:
        raw_header += ["bon_" + s for s in score_name_list]
    raw_header.append("count")

    for d in data:
        # The prompt text lives under either "conversations" or "prompt".
        if "conversations" in d:
            prompt = d["conversations"][0]["value"]
        else:
            prompt = d["prompt"][0]["value"]

        # Content rewards, when present, are merged into the sample's rewards.
        rewards = d["rewards"]
        if "reward_content" in d:
            rewards.update(d["reward_content"])

        genre_tags = d["tags"]["体裁"]
        domain = "创作"
        ticai = genre_tags["体裁"]
        if ticai == "宋词":
            if genre_tags["词牌名"] in useful_cipai:
                sub_domain = "宋词_常用词牌"
            else:
                sub_domain = "宋词_非常用词牌"
        else:
            sub_domain = ""

        # Only 宋词 samples have a sub-domain; others are not reported there.
        groupings = [("domain", domain), ("ticai", ticai)]
        if sub_domain != "":
            groupings.append(("sub_domain", sub_domain))

        for level, type_name in groupings:
            bucket = res_dic[level].setdefault(type_name, {"all": []})
            # The per-prompt list must be keyed (and looked up) by the prompt:
            # checking the type name instead re-created the list for every
            # record, so best-of-n only ever saw the last sample of a prompt.
            bucket.setdefault(prompt, []).append(rewards)
            bucket["all"].append(rewards)
            if type_name not in type_dic[level]:
                type_dic[level].append(type_name)
    print("finish process data")

    # Collect the rows first and build the frame once; growing a DataFrame
    # with pd.concat inside the loop copies the whole frame on every row.
    rows = []
    for concrete_type, type_names in type_dic.items():
        for t in type_names:
            one_data = {"level": concrete_type, "类型/词牌": t}

            usable_rates, bon_res = get_usable_rate(res_dic[concrete_type][t])
            for rate_name in score_name_list:
                columns = [(rate_name, usable_rates[rate_name])]
                if args.bon_res:
                    columns.append(("bon_" + rate_name, bon_res[rate_name]))
                for column, (scored, total) in columns:
                    # A metric that was never scored is left blank, not 0.
                    if scored == 0:
                        one_data[column] = ""
                    else:
                        one_data[column] = round(1.0 * total / scored, 2)
            one_data["count"] = len(res_dic[concrete_type][t]["all"])
            rows.append(one_data)

    out_df = pd.DataFrame(rows, columns=raw_header)
    out_df.to_excel(output_path, index=False)
