import random
import sys
import json
import os
import datetime
from reward_func import poem_reward, reward_combine
import argparse

# Command-line interface. NOTE: arguments are parsed at import time (module
# level), so even importing this module requires both flags to be on sys.argv.
parser = argparse.ArgumentParser()
for _option in ("--json-path", "--output-path"):
    # --json-path: JSONL file read by the __main__ entry point;
    # --output-path: JSONL file the scored records are written to.
    parser.add_argument(_option, type=str, required=True)
del _option  # keep the loop variable out of the module namespace
args = parser.parse_args()


def read_jsonl(file_name):
    """Load a JSON Lines file into a list of decoded objects.

    Each non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines, such as an extra empty line at the end of the
    file, are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_name: Path to a UTF-8 encoded JSONL file.

    Returns:
        list: The decoded objects, in file order.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def read_json(file_name):
    """Load a single JSON document from *file_name*.

    The file is decoded as UTF-8 explicitly, the same as ``read_jsonl``.
    Without an explicit encoding, ``open`` uses the locale's preferred
    encoding, which fails on non-ASCII (e.g. Chinese) content on
    non-UTF-8 systems.

    Args:
        file_name: Path to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value.
    """
    with open(file_name, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_json(data, file_name):
    """Write *data* to *file_name* as one UTF-8 JSON document.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), not as
    ``\\uXXXX`` escapes.
    """
    # Serialize before opening: if *data* is not JSON-serializable the error
    # is raised before the target file is created or truncated.
    text = json.dumps(data, ensure_ascii=False)
    with open(file_name, mode='w', encoding='utf-8') as out:
        out.write(text)

def save_jsonl(data, file_name):
    """Write each item of *data* to *file_name* as one JSON object per line.

    Output uses ``ensure_ascii=False``, so non-ASCII text (e.g. Chinese keys
    and values) is written as-is. The file is therefore opened as UTF-8
    explicitly. Otherwise the locale's default encoding is used, which can
    raise ``UnicodeEncodeError`` or produce a file that a UTF-8 reader
    cannot decode.

    Args:
        data: Iterable of JSON-serializable objects.
        file_name: Destination path; an existing file is overwritten.
    """
    with open(file_name, "w", encoding="utf-8") as file:
        for record in data:
            file.write(json.dumps(record, ensure_ascii=False) + "\n")


def get_reward(data):
    """Score every record of *data* in place with ``poem_reward``.

    Each record must have ``gpt_res == 1`` (checked with ``assert``) and
    provide ``"decom_dic"`` and ``"tags"["体裁"]`` (presumably the poem's
    genre/form; confirm against the data producer). Three keys are added to
    each record, in this order:

    * ``"rewards"``: first value returned by ``poem_reward``.
    * ``"final_reward"``: ``reward_combine`` applied to those rewards.
    * ``"体裁相关内容"``: second value returned by ``poem_reward``.

    Returns:
        None; the input records are mutated.
    """
    for record in data:
        assert record["gpt_res"] == 1
        decomposed = record["decom_dic"]
        genre = record["tags"]["体裁"]
        reward_vec, genre_details = poem_reward(decomposed, genre)
        # Assignment order fixes the key order in the saved JSON output.
        record["rewards"] = reward_vec
        record["final_reward"] = reward_combine(reward_vec)
        record["体裁相关内容"] = genre_details
    return None

def check_reward(reward1, reward2, idx1, idx2):
    """Decide which of two candidates wins across per-key reward scores.

    For every key of *reward1* (which must also be in *reward2*, and the two
    must have equal length; both are checked with ``assert``), the two scores
    are compared. Keys where both scores are ``None`` are ignored. A single
    ``None`` counts as ``0.0``. Equal scores express no preference.

    Returns:
        tuple: ``(valid, chosen, reject)``.

        * If no key expresses a preference: ``(False, idx1, idx2)``.
        * Otherwise ``chosen``/``reject`` come from the first differing key,
          and ``valid`` is True only when every differing key favors the
          same candidate.
    """
    assert len(reward1) == len(reward2)
    verdicts = []  # (winner, loser) for each key on which the scores differ
    for key in reward1:
        assert key in reward2
        left, right = reward1[key], reward2[key]
        if left is None and right is None:
            continue
        left = 0.0 if left is None else left
        right = 0.0 if right is None else right
        if left > right:
            verdicts.append((idx1, idx2))
        elif left < right:
            verdicts.append((idx2, idx1))

    if not verdicts:
        return False, idx1, idx2

    chosen, reject = verdicts[0]
    consistent = not any(winner != chosen for winner, _loser in verdicts)
    return consistent, chosen, reject

if __name__ == '__main__':
    # Print a banner, score every record of the input JSONL, then write the
    # annotated records to the output path.
    print("-----------------------", "Poem get reward.", sep="\n")
    samples = read_jsonl(args.json_path)
    get_reward(samples)  # mutates each sample in place
    save_jsonl(samples, args.output_path)