import random
import json
import os
import datetime
import argparse
import jsonlines
from reward_func import poem_reward, reward_combine
import re
from tqdm import *

# Command-line interface: both the input and the output location are mandatory.
parser = argparse.ArgumentParser()
for cli_flag in ("--json-path", "--output-path"):
    parser.add_argument(cli_flag, type=str, required=True)
args = parser.parse_args()

# Load every record of the JSON Lines input file into memory as a list of
# decoded objects (one per input line), preserving file order.
with jsonlines.open(args.json_path, "r") as reader:
    input_data = list(reader)

def get_reward(data):
    """Score one decomposed record in place; return nothing.

    Passes ``data["decom_dic"]`` and the ``data["tags"]["体裁"]`` tag
    (presumably the poem's genre/form — confirm against ``reward_func``) to
    ``poem_reward``, which returns a pair. The record is then extended with:

    - ``"rewards"``: the first element returned by ``poem_reward``;
    - ``"final_reward"``: ``reward_combine`` applied to that first element;
    - ``"体裁相关内容"``: the second element returned by ``poem_reward``.
    """
    scores, genre_details = poem_reward(data["decom_dic"], data["tags"]["体裁"])
    data["rewards"] = scores
    data["final_reward"] = reward_combine(scores)
    data["体裁相关内容"] = genre_details

# Separators used to cut the poem body into lines/phrases: full-width and ASCII
# commas, 、, 。, ？, ！, ：, curly and straight double quotes, title brackets,
# newlines and spaces. Compiled once instead of on every record.
# NOTE(review): full-width "；" and ASCII "?", "!", ";", ":" are NOT separators;
# confirm whether that is intended.
_POEM_SPLIT_RE = re.compile(r'[，、。？！,：“”"《》\n ]')

output_data = []
for row_idx, data in enumerate(tqdm(input_data)):
    gen = data["gen"]

    # The title is the text between the first "《" and the first "》" that
    # follows it; everything after that "》" is treated as the poem body.
    index_left = gen.find("《")
    index_right = gen.find("》", index_left + 1) if index_left != -1 else -1
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if index_left == -1 or index_right == -1 or index_right - index_left <= 1:
        raise ValueError(
            f"record {row_idx}: no non-empty 《title》 found in 'gen': {gen[:50]!r}"
        )

    poem_lines = [s for s in _POEM_SPLIT_RE.split(gen[index_right + 1:]) if s]
    data["decom_dic"] = {
        "标题": gen[index_left + 1:index_right],
        "诗歌": poem_lines,
    }

    # Adds "rewards", "final_reward" and "体裁相关内容" to `data` in place.
    get_reward(data)
    output_data.append(data)

# Write the scored records as JSON Lines (one JSON object per line).
# ensure_ascii=False emits the Chinese text verbatim, so the file must be
# opened as UTF-8 explicitly; relying on the platform's locale encoding can
# raise UnicodeEncodeError (e.g. on Windows).
with open(args.output_path, "w", encoding="utf-8") as file:
    for d in output_data:
        json_data = json.dumps(d, ensure_ascii=False)
        file.write(json_data + "\n")
