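# 字数分 (length score): rewards answers by how well their character/word count meets the requested constraint.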
import json
import re
import jsonlines
import argparse
from utils import read_jsonl, save_jsonl

# Command-line interface: both paths are mandatory string options.
parser = argparse.ArgumentParser()
for _flag in ("--json-path", "--output-path"):
    parser.add_argument(_flag, type=str, required=True)
args = parser.parse_args()

def is_eng_character(c):
    """Return True when ``c`` lies in the ASCII ranges a-z or A-Z.

    The check uses plain string ordering; it is intended for single
    characters.
    """
    is_lower = 'a' <= c <= 'z'
    is_upper = 'A' <= c <= 'Z'
    return is_lower or is_upper
def is_seperate(c):
    """Return True when ``c`` separates English words.

    A character is a separator unless it is an ASCII letter or one of the
    joiners '-', '_' and '·', which keep a word going.
    """
    if is_eng_character(c):
        return False
    return c not in ('-', '_', '·')
def get_token_count(token_type, ans):
    """Count the tokens of kind ``token_type`` found in the string ``ans``.

    ``"汉字"`` counts every character in the range U+4E00..U+9FA5.
    ``"英文单词"`` counts English words: a word opens at an ASCII letter and
    stays open until ``is_seperate`` reports a separator, so letters joined
    by '-', '_' or '·' are counted as one word.

    Any other ``token_type`` fails the assertion.
    """
    assert token_type in ("汉字", "英文单词")

    if token_type == "汉字":
        han_chars = re.compile(r'[\u4e00-\u9fa5]').findall(ans)
        return len(han_chars)

    # English words: two-state scan (outside a word / inside a word).
    words = 0
    inside_word = False
    for ch in ans:
        if inside_word:
            # Only a separator closes the current word.
            if is_seperate(ch):
                inside_word = False
        elif is_eng_character(ch):
            # A letter seen outside a word opens a new one.
            words += 1
            inside_word = True
    return words

def get_reward(count, relation, token_type, ans_list):
    """Score each answer in ``ans_list`` against a length constraint.

    Args:
        count: Target number of tokens; asserted to be greater than 1.
        relation: One of "严格等于" (exactly equal), "大于" (greater than)
            or "小于" (less than); asserted.
        token_type: Token kind forwarded to ``get_token_count``.
        ans_list: Iterable of answers to score.

    Returns:
        A pair ``(reward, token_counts)`` of parallel lists: the score of
        each answer and the token count it was computed from.

    Scoring, with ``n`` the token count of an answer:
        * 严格等于: ``2 * min(n, count) / (n + count)``.
        * 大于: 1.0 when ``n > count``, else ``2 * n / (n + count + 1)``.
        * 小于: 1.0 when ``n < count``, else
          ``2 * (count - 1) / (n + count - 1)``.
    """
    assert relation in ("严格等于", "大于", "小于")
    assert count > 1

    def _score(n_tokens):
        # Score of a single answer under the requested relation.
        if relation == "严格等于":
            return 2.0 * min(n_tokens, count) / (n_tokens + count)
        if relation == "大于":
            if n_tokens > count:
                return 1.0
            return 2.0 * n_tokens / (n_tokens + count + 1)
        # Remaining relation: 小于.
        if n_tokens < count:
            return 1.0
        return 2.0 * (count - 1) / (n_tokens + count - 1)

    token_counts = [get_token_count(token_type, ans) for ans in ans_list]
    reward = [_score(n_tokens) for n_tokens in token_counts]
    return reward, token_counts

# Announce this stage on stdout.
print("-----------------------")
print("Slogan reward count.")

# read_jsonl receives the list to load records into.
data = []
read_jsonl(args.json_path, data)

for record in data:
    # Create the result containers on first use; existing ones are kept.
    reward_slot = record.setdefault("rewards", {})
    explanation_slot = record.setdefault("explanation", {})

    tags = record["tags"]
    scores, token_counts = get_reward(
        tags["字数"],
        tags["关系"],
        tags["粒度"],
        record["decom_dic"]["内容"],
    )

    # Scores go under "rewards"; the raw token counts under "explanation".
    reward_slot["count"] = scores
    explanation_slot["count"] = token_counts

save_jsonl(data, args.output_path)
