import numpy as np
from utils import read_json
from pypinyin import lazy_pinyin, Style

rhyme_dic = {
    "阴平一东": "东同童僮铜桐峒筒瞳中衷忠盅虫冲终忡崇嵩崧菘戎绒弓躬宫穹融雄熊穷冯风枫疯丰充隆窿空公功工攻蒙蒙朦瞢笼胧栊咙聋珑砻泷蓬篷洪荭红虹鸿丛翁嗡匆葱聪骢通棕烘崆", 
    "阴平二冬": "冬咚彤农侬宗淙锺钟龙茏舂松淞冲容榕蓉溶庸佣慵封胸凶匈汹雍邕痈浓脓重从逢缝峰锋丰蜂烽葑纵踪茸蛩邛筇跫供蚣喁",
    "阴平五微": "微薇晖辉徽挥韦围帏违闱霏菲妃飞非扉肥威祈畿机讥玑稀希衣依归饥矶欷诽绯晞葳巍沂圻颀",
    "阴平十灰": "灰恢魁隈回徊槐梅枚玫媒煤雷颓崔催摧堆陪杯醅嵬推诙裴培盔偎煨瑰茴追胚徘坯桅傀儡莓开哀埃台苔抬该才材财裁栽哉来莱灾猜孩徕骀胎唉垓挨皑呆腮",
    "阴平十四寒": "寒韩翰丹单安鞍难餐檀坛滩弹残干肝竿阑栏澜兰看刊丸完桓纨端湍酸团攒官观鸾銮峦冠欢宽盘蟠漫叹邯郸摊玕拦珊狻鼾杆跚姗殚箪瘅谰獾倌棺剜潘拼盘般蹒瘢磐瞒谩馒鳗钻抟邗汗",
    "阳平一先": "先前千阡笺天坚肩贤弦烟燕莲怜连田填巅鬈宣年颠牵妍研眠渊涓捐娟边编悬泉迁仙鲜钱煎然延筵毡旃蝉缠廛联篇偏绵全镌穿川缘鸢旋船涎鞭专圆员乾虔愆权拳椽传焉嫣鞯褰搴铅舷跹鹃筌痊诠悛先邅禅婵躔颛燃涟琏便翩骈癫阗钿沿蜒胭芊鳊胼滇佃畋咽湮狷蠲蔫骞膻扇棉拴荃籼砖挛儇璇卷扁溅犍",
    "阳平五歌": "歌多罗河戈阿和波科柯陀娥蛾鹅萝荷何过磨螺禾珂蓑婆坡呵哥轲沱鼍拖驼跎佗颇峨俄摩么娑莎迦疴苛蹉嵯驮箩逻锣哪挪锅诃窠蝌髁倭涡窝讹陂鄱皤魔梭唆骡挼靴瘸搓哦瘥酡",
    "阳平八庚": "庚更羹盲横觥彭亨英烹平枰京惊荆明盟鸣荣莹兵兄卿生甥笙牲擎鲸迎行衡耕萌甍宏闳茎罂莺樱泓橙争筝清情晴精睛菁晶旌盈楹瀛嬴赢营婴缨贞成盛城诚呈程酲声征正轻名令并倾萦琼峥嵘撑粳坑铿撄鹦黥蘅澎膨棚浜坪苹钲伧檠嘤轰铮狰宁狞瞪绷怦璎砰氓鲭侦柽蛏茔赪茕赓黉瞠", 
    "阳平九青": "青经泾形陉亭庭廷霆蜓停丁仃馨星腥醒惺俜灵龄玲铃伶零听冥溟铭瓶屏萍荧萤荣扃垧蜻硎苓聆瓴翎娉婷宁暝瞑螟猩钉疔叮厅町泠棂囹羚蛉咛型邢",
    "上声一董": "董懂动孔总笼拢桶捅蓊蠓汞",
    "上声四纸": "纸只咫是靡彼毁委诡髓累技绮觜此泚蕊徙尔弭婢侈弛豕紫旨指视美否痞兕几姊比水轨止徴市喜已纪跪妓蚁鄙晷子仔梓矢雉死履垒癸趾址以已似耜祀史驶耳使里理李起杞圯跂士仕俟始齿矣耻麂枳峙鲤迩氏玺巳滓苡倚匕迤逦旖旎舣蚍秕芷拟你企诔捶屣棰揣豸祉恃", 
    "上声八荠": "荠礼体米启陛洗邸底抵弟坻柢涕悌济澧醴诋眯娣棨递昵睨蠡",
    "上声十四旱": "旱暖管琯满短馆缓盥碗懒伞伴卵散伴诞罕瀚断侃算款但坦袒纂缎拌懑谰莞",
    "去声一送": "送梦凤洞众瓮贡弄冻痛栋恸仲中粽讽空控哄赣",
    "去声九泰": "泰太带外盖大濑赖籁蔡害蔼艾丐奈柰汰癞霭会旆最贝沛霈绘脍荟狈侩桧蜕酹外兑",
    "去声十卦": "卦挂画懈廨邂隘卖派债怪坏诫戒界介芥械薤拜快迈败稗晒瀣湃寨疥届蒯篑蒉喟聩块惫",
}

rhyme_list = ["阴平一东", "阴平二冬", "阴平五微", "阴平十灰", "阴平十四寒", "阳平一先", "阳平五歌", "阳平八庚", "阳平九青", "上声一董", "上声四纸", "上声八荠", "上声十四旱", "去声一送", "去声九泰", "去声十卦"]

new_rhyme_dic = {
                    "一麻": ['a', 'ia', 'ua'], 
                    "二波": ['o', 'uo'],
                    "三歌": ['e'],
                    "四皆": ['ie', 'üe', 've'],
                    "五支": ['i'],
                    "六儿": ['er'],
                    "七齐": ['i'],
                    "八微": ['ei', 'ui', 'uei'],
                    "九开": ['ai', 'uai'],
                    "十姑": ['u'],
                    "十一鱼": ['ü', 'v'],
                    "十二侯": ['ou', 'iu', 'iou'],
                    "十三豪": ['ao', 'iao'],
                    "十四寒": ['an', 'ian', 'uan', 'üan', 'van'],
                    "十五痕": ['en', 'in', 'un', 'ün', 'vn', 'uen'],
                    "十六唐": ['ang', 'iang', 'uang'],
                    "十七庚": ['eng', 'ing', 'ueng'],
                    "十八东": ['ong', 'iong'],
                }

new_14_rhyme_dic = {
                    "一麻": ['a', 'ia', 'ua'], 
                    "二波": ['o', 'e', 'uo'],
                    "三皆": ['ie', 'üe', 've'],
                    "四开": ['ai', 'uai'],
                    "五微": ['ei', 'ui', 'uei'],
                    "六豪": ['ao', 'iao'],
                    "七尤": ['ou', 'iu', 'iou'],
                    "八寒": ['an', 'ian', 'uan', 'üan', 'van'],
                    "九文": ['en', 'in', 'un', 'ün', 'vn', 'uen'],
                    "十唐": ['ang', 'iang', 'uang'],
                    "十一庚": ['eng', 'ing', 'ueng', 'ong', 'iong'],
                    "十二齐": ['i', 'er', 'ü', 'v'],
                    "十三支": ['i'],
                    "十四姑": ['u'],
                }

all_yunmus = ['a', 'ia', 'ua', 'o', 'uo', 'e', 'ie', 'üe', 've', 'i', 'er', '-i', 'ei', 'ui', 'uei', 'ai', 'uai', 'u', 'ü', 'v', 'ou', 'iu', 'iou', 'ao', 'iao', 'an', 'ian', 'uan', 'üan', 'van', 'en', 'in', 'un', 'ün', 'vn', 'uen', 'ang', 'iang', 'uang', 'eng', 'ing', 'ueng', 'ong', 'iong']

rhyme_i = [['zh', 'ch', 'sh', 'z', 'c', 's', 'r'], 
           ['b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'j', 'q', 'x', 'y']]

weight = {
        "标题": 0.5,
        "行数": 1,
        "每行字数": 1,
        "总字数": 1,
        "押韵": 1,
        "平仄": 1,
        "开头": 1,
        "结尾": 1,
        "包含": 1,
        "不包含": 1,
        "藏头": 1,
        "藏中-字数分": 1,
        "藏中-顺序分": 1,
}

def load_songci_dic(name):
    songci_75 = read_json("")
    if name in songci_75:
        return True, songci_75
    else:
        songci_151 = read_json("")
        return False, songci_151
    
def load_songci_rhyme_dic(name, variant):
    songci_rhyme = read_json("")
    res = {}
    for v in variant:
        res[v] = songci_rhyme[name][v]
    return res

def load_tone_dic(ticai, name, variant_set=None):
    if ticai == "绝句" and name == "五言":
        dic = read_json("")
        return dic["五绝"]
    elif ticai == "绝句" and name == "七言":
        dic = read_json("")
        return dic["七绝"]
    elif ticai == "律诗" and name == "五言":
        dic = read_json("")
        return dic["五律"]
    elif ticai == "律诗" and name == "七言":
        dic = read_json("")
        return dic["七律"]
    elif ticai == "宋词" and variant_set is not None:
        dic = read_json("")
        res_dic = {}
        for v in variant_set:
            res = dic[name][v]
            tone_list_all = []
            for r in res:
                tone_list_all.extend(r)
            res_dic[v] = tone_list_all
        return res_dic

def poem_reward(content, tags):
    reward_dic = {
        "标题": None,
        "行数": None,
        "每行字数": None,
        "总字数": None,
        "押韵": None,
        "平仄": None,
        "开头": None,
        "结尾": None,
        "包含": None,
        "不包含": None,
        "藏头": None,
        "藏中-字数分": None,
        "藏中-顺序分": None,
    }
    poem = content["诗歌"]
    if len(poem) == 0:
        raise ValueError("Wrong content", content)
    title_reward(content, reward_dic, tags)

    assert "体裁" in tags
    ticai_res_dic = ticai_reward(poem, reward_dic, tags)

    if "总字数" in tags:
        total_count_reward(poem, reward_dic, tags["总字数"])

    if "包含" in tags:
        contain_reward(poem, reward_dic, tags["包含"])

    if "不包含" in tags:
        not_contain_reward(poem, reward_dic, tags["不包含"])
    
    if "藏头" in tags:
        head_reward(poem, reward_dic, tags["藏头"])

    if "藏中" in tags:
        body_reward(poem, reward_dic, tags["藏中"])

    if "开头" in tags:
        start_reward(poem, reward_dic, tags["开头"])

    if "结尾" in tags:
        tail_reward(poem, reward_dic, tags["结尾"])
    
    if "风格" in tags:
        type_reward(poem, reward_dic, tags["风格"])
    return reward_dic, ticai_res_dic

def title_reward(content, reward_dic, tags):
    if content["标题"] == "":
        reward_dic["标题"] = 0.0
    # elif "主题" in tags and tags["主题"] in content["标题"]:
    #     reward_dic["标题"] = 0.5
    else:
        reward_dic["标题"] = 1.0
    if "藏头" in tags and tags["藏头"] in content["标题"]:
        reward_dic["标题"] = 0.2
    return

def ticai_reward(poem, reward_dic, tags): 
    ## 体裁奖励包括行数，字数，押韵和平仄
    tag = tags["体裁"]
    word_limit = 0
    word_limit_lists = []
    rhyme_word_list = []
    rhyme_word_dic = {}
    word_limits = []
    rhyme_words_res = {}
    use_variant = set()
    tone_res = ["", ""]

    ## 字数奖励
    if tag == "绝句" or tag == "律诗":
        if "五七言" in tags:
            if tags["五七言"] == "五言":
                word_limit = 5
            elif tags["五七言"] == "七言":
                word_limit = 7
            else:
                raise ValueError("Wrong value for 五七言", tags["五七言"])
        else:
            word_limits = [5, 7]
    elif tag == "古体诗" and "每行字数" in tags:
        if tags["每行字数"] == "四言":
            word_limit = 4
        elif tags["每行字数"] == "五言":
            word_limit = 5
        elif tags["每行字数"] == "六言":
            word_limit = 6
        elif tags["每行字数"] == "七言":
            word_limit = 7
        elif tags["每行字数"] == "八言":
            word_limit = 8
        else:
            raise ValueError("Wrong value for 字数", tags["每行字数"])
    elif tag == "宋词":
        assert "词牌名" in tags
        is_songci_75, songci_dic = load_songci_dic(tags["词牌名"])
        # if "词牌名" in tags:
        #     word_limit_lists = songci_dic[tags["词牌名"]]
        # else:
        #     raise ValueError("No cipaiming", poem)
        if is_songci_75:
            for k, v in songci_dic[tags["词牌名"]].items():
                word_limit_lists.append(v)
        else:
            word_limit_lists = songci_dic[tags["词牌名"]]
        assert len(word_limit_lists) > 0
    elif tag == "现代诗":
        # ## 版本1：各行统计字数，独特的字数个数占总行数比例
        # total_hang = len(poem)
        # if total_hang > 0:
        #     lines_count_set = set([len(x) for x in poem])
        #     reward_dic["每行字数"] = len(lines_count_set) / total_hang
        ## 版本2：统计每行和上一行的字数，统计不相等的行数的比例
        if len(poem) <= 1:
            reward_dic["每行字数"] = 1.0
        else:
            total_hang = len(poem) - 1
            same_hang = 0
            for i in range(total_hang):
                if len(poem[i]) == len(poem[i+1]):
                    same_hang += 1
            reward_dic["每行字数"] = 1.0 - 1.0 * same_hang / total_hang
    if word_limit > 0:
        all_lenght = len(poem)
        correct_count = 0
        for x in poem:
            if len(x) == word_limit:
                correct_count += 1
        reward_dic["每行字数"] = 1.0 * correct_count / all_lenght
    if len(word_limits) > 0:
        all_lenght = len(poem)
        correct_count = 0
        wuqiyan_dic = {
            5: "五言",
            7: "七言"
        }
        for wl in word_limits:
            cc = 0
            for x in poem:
                if len(x) == wl:
                    cc += 1
            if cc >= correct_count:
                correct_count = cc
            tags["五七言"] = wuqiyan_dic[wl]
        reward_dic["每行字数"] = 1.0 * correct_count / all_lenght
    if len(word_limit_lists) > 0:
        max_reward = 0.0
        for word_limit_list in word_limit_lists:
            word_limit_list_all = []
            for cur_list in word_limit_list:
                word_limit_list_all.extend(cur_list)
            size_min = min(len(word_limit_list_all), len(poem))
            size_max = max(len(word_limit_list_all), len(poem))
            correct_count = 0
            for i in range(size_min):
                if len(poem[i]) == word_limit_list_all[i]:
                    correct_count += 1
            current_reward = 1.0 * correct_count / size_max
            if current_reward > max_reward:
                max_reward = current_reward
                if is_songci_75:
                    for k, v in songci_dic[tags["词牌名"]].items():
                        if v == word_limit_list:
                            use_variant = set()
                            use_variant.add(k)
            if current_reward == max_reward and is_songci_75:
                for k, v in songci_dic[tags["词牌名"]].items():
                    if v == word_limit_list:
                        use_variant.add(k)
        reward_dic["每行字数"] = max_reward

    ## 行数奖励
    if tag == "绝句" or tag == "律诗" or tag == "宋词":
        if tag == "绝句":
            line_limit = 4
            line_gap = np.abs(len(poem) - line_limit)
            reward_dic["行数"] = 1.0 - 1.0 / line_limit * min(line_limit, line_gap)
        elif tag == "律诗":
            line_limit = 8
            line_gap = np.abs(len(poem) - line_limit)
            reward_dic["行数"] = 1.0 - 1.0 / line_limit * min(line_limit, line_gap)
        else:
            if is_songci_75:
                line_limits = []
                songci_line_dic = {}
                for uv in use_variant:
                    lines = np.sum([len(x) for x in songci_dic[tags["词牌名"]][uv]])
                    if lines not in songci_line_dic:
                        songci_line_dic[lines] = []
                    songci_line_dic[lines].append(uv)
                    if lines not in line_limits:
                        line_limits.append(lines)
            else:
                line_limits = [np.sum([len(x) for x in word_limit_list]) for word_limit_list in word_limit_lists]
            line_reward = 0.0
            use_variant = set()
            for cur_limit in line_limits:
                cur_gap = np.abs(len(poem) - cur_limit)
                cur_reward = 1.0 - 1.0 / cur_limit * min(cur_limit, cur_gap)
                if cur_reward > line_reward:
                    line_reward = cur_reward
                    if is_songci_75:
                        use_variant = set(songci_line_dic[cur_limit])
                elif cur_reward == line_reward and is_songci_75:
                    use_variant.update(songci_line_dic[cur_limit])
            reward_dic["行数"] = line_reward
    else:
        if "行数" in tags:
            hang_tag = tags["行数"]
            hang_tag = hang_tag.removesuffix("行")
            if "不少于" in hang_tag:
                num = int(hang_tag.removeprefix("不少于"))
                line_gap = num - len(poem)
                if line_gap <= 0:
                    reward_dic["行数"] = 1.0
                else:
                    reward_dic["行数"] = 1.0 - 1.0 / num * line_gap
            elif "不多于" in hang_tag:
                num = int(hang_tag.removeprefix("不多于"))
                line_gap = len(poem) - num
                if line_gap <= 0:
                    reward_dic["行数"] = 1.0
                else:
                    reward_dic["行数"] = 1.0 - 1.0 / num * min(num, line_gap)
            else:
                num = int(hang_tag)
                line_gap = np.abs(num - len(poem))
                if line_gap == 0:
                    reward_dic["行数"] = 1.0
                else:
                    reward_dic["行数"] = 1.0 - 1.0 / num * min(num, line_gap)
    
    ## 平仄奖励
    if tag == "绝句" or tag == "律诗":
        assert "五七言" in tags
        tone_list = load_tone_dic(tag, tags["五七言"])
        reward_dic["平仄"], tone_res = check_tone(poem, tone_list)
    elif tag == "宋词" and is_songci_75:
        tone_dic = load_tone_dic(tag, tags["词牌名"], use_variant)
        reward_dic["平仄"], tone_res = check_tone(poem, tone_dic, use_dic=True)
        use_variant = tone_res[-1]
        tone_res = tone_res[:-1]

    ## 押韵奖励
    if tag == "绝句":
        rhyme = ""
        for i in [1, 3]:
            if i < len(poem):
                rhyme_word_list.append(poem[i][-1])
        if "押韵" in tags:
            rhyme = tags["押韵"]
    elif tag == "律诗":
        rhyme = ""
        for i in range(len(poem) // 2):
            rhyme_word_list.append(poem[2 * i + 1][-1])
        if "押韵" in tags:
            rhyme = tags["押韵"]
    elif tag == "宋词":
        if is_songci_75:
            rhyme = ""
            rhyme_dic = load_songci_rhyme_dic(tags["词牌名"], use_variant)
            for v, rhyme_res in rhyme_dic.items():
                cur_rhyme_word_list = []
                rhyme_list_all = []
                for cur_r in rhyme_res:
                    rhyme_list_all.extend(cur_r)
                size_min = min(len(poem), len(rhyme_list_all))
                for i in range(size_min):
                    if rhyme_list_all[i]:
                        cur_rhyme_word_list.append(poem[i][-1])
                rhyme_word_dic[v] = cur_rhyme_word_list
            if "押韵" in tags:
                rhyme = tags["押韵"]
    elif "押韵" in tags:
        rhyme = ""
        for i in range(len(poem)):
            rhyme_word_list.append(poem[i][-1])
            rhyme = tags["押韵"]
    if len(rhyme_word_list) > 0:
        reward_dic["押韵"], rhyme_words_res = check_rhyme(poem, rhyme_word_list, rhyme, use_14_yun=True)
    if len(rhyme_word_dic) > 0:
        use_variant = set()
        rhyme_reward = 0.0
        for k, rwl in rhyme_word_dic.items():
            if len(rwl) <= 0:
                continue
            cur_reward, rhyme_words_res = check_rhyme(poem, rwl, rhyme, use_14_yun=True)
            if cur_reward > rhyme_reward:
                rhyme_reward = cur_reward
                use_variant = set()
                use_variant.add(k)
            if cur_reward == rhyme_reward:
                use_variant.add(k)
        reward_dic["押韵"] = rhyme_reward
    
    return {
        "押韵结果": rhyme_words_res,
        # "使用平仄变体": tone_res[0],
        # "平仄结果": tone_res[1],
        "使用宋词变体": list(use_variant)
    }

def total_count_reward(poem, reward_dic, tag):
    word_tag = tag.removesuffix("字")
    count = np.sum([len(x) for x in poem])
    if "不少于" in word_tag:
        num = int(word_tag.removeprefix("不少于"))
        word_gap = num - count
        if word_gap <= 0:
            reward_dic["总字数"] = 1.0
        else:
            reward_dic["总字数"] = 1.0 - 1.0 / num * word_gap
    elif "不多于" in word_tag:
        num = int(word_tag.removeprefix("不多于"))
        word_gap = count - num
        if word_gap <= 0:
            reward_dic["总字数"] = 1.0
        else:
            reward_dic["总字数"] = 1.0 - 1.0 / num * min(num, word_gap)
    else:
        num = int(word_tag)
        word_gap = np.abs(num - count)
        if word_gap == 0:
            reward_dic["总字数"] = 1.0
        else:
            reward_dic["总字数"] = 1.0 - 1.0 / num * min(num, word_gap)

def contain_reward(poem, reward_dic, tag): 
    valid = False
    for l in poem:
        if tag in l:
            valid = True
            break
    if valid:
        reward_dic["包含"] = 1.0
    else:
        reward_dic["包含"] = 0.0
    return

def not_contain_reward(poem, reward_dic, tag): 
    valid = True
    for l in poem:
        if tag in l:
            valid = False
            break
    if valid:
        reward_dic["不包含"] = 1.0
    else:
        reward_dic["不包含"] = 0.0
    return

def head_reward(poem, reward_dic, tag): 
    notes = ['，', '、', '。', '？', '！', ',', '：', '“', '”', '\"', '"', "《", "》"]
    tag = [c for c in tag if c not in notes]
    total = len(tag)
    correct_count = 0
    size = min(len(poem), len(tag))
    for i in range(size):
        if tag[i] == poem[i][0]:
            correct_count += 1
    reward_dic["藏头"] = 1.0 * correct_count / total
    return

def body_reward_origin(poem, reward_dic, tag):
    assert len(tag) > 1
    position_info = [-1 for _ in range(len(tag))]
    index = 0
    previous_sentence = ""
    sentence_same = True
    for sentence in poem:
        for token in sentence:
            for i in range(len(tag)):
                if token == tag[i] and position_info[i] == -1:
                    position_info[i] = index
                    if previous_sentence != "" and sentence != previous_sentence:
                        sentence_same = False
                    previous_sentence = sentence
                    break
            index += 1
    reward_dic["藏中"] = 0.
    for i in range(len(tag)):
        token_reward = 1.0 / len(tag) if position_info[i] != -1 else 0
        if i > 0 and position_info[i] < position_info[i - 1]:
            token_reward *= 0.5
        reward_dic["藏中"] += token_reward
    if sentence_same:
        reward_dic["藏中"] *= 0.5

def cal(index_list, list_len):
    count = 0
    for i in range(list_len):
        if index_list[i] == -1:
            continue
        for j in range(i + 1, list_len):
            if index_list[j] == -1:
                continue
            if index_list[i] < index_list[j]:
                count += 1
    return count

def body_reward(poem, reward_dic, tag):
    tag_len = len(tag)
    position_info = [[-1] for _ in range(4)]
    index = 0
    for sentence in poem:
        for token in sentence:
            for i in range(tag_len):
                if tag[i] == token:
                    position_info[i].append(index)
                    break
            index += 1
    reward_dic["藏中-字数分"] = 0.
    reward_dic["藏中-顺序分"] = 0.
    for i in range(len(tag)):
        if len(position_info[i]) > 1:
            reward_dic["藏中-字数分"] += 1./tag_len
    max_order_pair = tag_len*(tag_len - 1)/2.
    for index0 in position_info[0]:
        for index1 in position_info[1]:
            for index2 in position_info[2]:
                for index3 in position_info[3]:
                    reward_dic["藏中-顺序分"] = max(reward_dic["藏中-顺序分"], cal([index0, index1, index2, index3], tag_len)/max_order_pair)

def start_reward(poem, reward_dic, tag): 
    total = len(tag)
    correct_count = 0
    first = poem[0]
    size = min(len(first), len(tag))
    for i in range(size):
        if tag[i] == first[i]:
            correct_count += 1
    reward_dic["开头"] = 1.0 * correct_count / total
    return

def tail_reward(poem, reward_dic, tag): 
    total = len(tag)
    correct_count = 0
    last = poem[-1]
    size = min(len(last), len(tag))
    for i in range(size):
        if tag[-i-1] == last[-i-1]:
            correct_count += 1
    reward_dic["结尾"] = 1.0 * correct_count / total
    return

def type_reward(poem, reward_dic, tag): 
    return

def reward_combine(reward_dic):
    final_reward = 0
    for k, v in reward_dic.items():
        if v is not None:
            final_reward += 1.0 * v * weight[k]
    return final_reward

def check_rhyme(poem, last_words, rhyme, use_strict_rule=False, use_14_yun=False):
    if len(last_words) <= 0:
        return 0.0
    total_size = len(last_words)
    last_words_rhymes = {}
    if rhyme in rhyme_list:
        correct_count = 0
        rhyme_str = rhyme_dic[rhyme]
        for w in last_words:
            if w in rhyme_str:
                correct_count += 1
                if rhyme not in last_words_rhymes:
                    last_words_rhymes[rhyme] = []
                last_words_rhymes[rhyme].append(w)
        return 1.0 * correct_count / total_size, last_words_rhymes
    elif rhyme in all_yunmus:
        last_words_rhymes[rhyme] = []
        for w in last_words:
            yunmu = lazy_pinyin(w, style=Style.FINALS)[0]
            yun = None

            if yunmu == 'i':
                shengmu = lazy_pinyin(w, style=Style.INITIALS)[0]
                if shengmu in ['zh', 'ch', 'sh', 'z', 'c', 's', 'r']:
                    yun = "-i"
                else:
                    yun = "i"
            else:
                yun = yunmu
            if yun == rhyme:
                last_words_rhymes[rhyme].append(w)
        return 1.0 * len(last_words_rhymes[rhyme]) / total_size, last_words_rhymes
    else:
        fix_yun = rhyme
        for w in last_words:
            yunmu = lazy_pinyin(w, style=Style.FINALS)[0]
            yun = None

            if use_strict_rule:
                if yunmu == 'i':
                    shengmu = lazy_pinyin(w, style=Style.INITIALS)[0]
                    if shengmu in ['zh', 'ch', 'sh', 'z', 'c', 's', 'r']:
                        yun = "-i"
                    else:
                        yun = "i"
                else:
                    yun = yunmu
                if fix_yun != "" and yun != fix_yun:
                    yun = ""
            elif use_14_yun:
                if yunmu == 'i':
                    shengmu = lazy_pinyin(w, style=Style.INITIALS)[0]
                    if shengmu in ['zh', 'ch', 'sh', 'z', 'c', 's', 'r']:
                        yun = "十三支"
                    else:
                        yun = "十二齐"
                else:
                    for k, v in new_14_rhyme_dic.items():
                        if yunmu in v:
                            yun = k
                            break
                if fix_yun != "" and yun != fix_yun:
                    yun = ""
            else:
                if yunmu == 'i':
                    shengmu = lazy_pinyin(w, style=Style.INITIALS)[0]
                    if shengmu in ['zh', 'ch', 'sh', 'z', 'c', 's', 'r']:
                        yun = "五支"
                    else:
                        yun = "七齐"
                else:
                    for k, v in new_rhyme_dic.items():
                        if yunmu in v:
                            yun = k
                            break
                if fix_yun != "" and yun != fix_yun:
                    yun = ""
            if yun is None:
                print(poem)
                print(last_words)
                print("Wrong words", w, yunmu)
                continue
            if yun not in last_words_rhymes:
                last_words_rhymes[yun] = []
            if yun != "":
                last_words_rhymes[yun].append(w)
        if len(last_words_rhymes) > 0:
            correct_yun = max([len(tuple(v)) for v in last_words_rhymes.values()])
        else:
            correct_yun = 0.0
        return 1.0 * correct_yun / total_size, last_words_rhymes

def check_tone(poem, tones, use_dic=False):
    chosen_tone = None
    correct_res = []
    tone_reward = 0.0
    total_word_count = np.sum([len(s) for s in poem])
    tone_dic = {
        "平": [0, 1, 2],
        "仄": [0, 3, 4],
        "中": [0, 1, 2, 3, 4]
    }

    if use_dic:
        use_keys = set()
        for k, tone in tones.items():
            correct_tone_words = 0
            cur_correct_res = []
            
            size_min = min(len(poem), len(tone))
            for i in range(size_min):
                tone_s = tone[i]
                p_s = poem[i]
                p_s_tones = lazy_pinyin(p_s, style=Style.TONE3)
                cur_line_res = []
                for j in range(min(len(tone_s), len(p_s_tones))):
                    if p_s_tones[j][-1] in ['1', '2', '3', '4']:
                        cur_t = int(p_s_tones[j][-1])
                    else:
                        cur_t = 0
                    if cur_t in tone_dic[tone_s[j]]:
                        correct_tone_words += 1
                        cur_line_res.append(1)
                    else:
                        cur_line_res.append(0)
                cur_correct_res.append(cur_line_res)
            cur_reward = 1.0 * correct_tone_words / total_word_count
            if cur_reward > tone_reward:
                chosen_tone = tone
                tone_reward = cur_reward
                correct_res = cur_correct_res
                use_keys = set()
                use_keys.add(k)
            elif cur_reward == tone_reward:
                use_keys.add(k)
        return tone_reward, [chosen_tone, correct_res, use_keys]
    else:
        for tone in tones:
            correct_tone_words = 0
            cur_correct_res = []
            
            size_min = min(len(poem), len(tone))
            for i in range(size_min):
                tone_s = tone[i]
                p_s = poem[i]
                p_s_tones = lazy_pinyin(p_s, style=Style.TONE3)
                cur_line_res = []
                for j in range(min(len(tone_s), len(p_s_tones))):
                    if p_s_tones[j][-1] in ['1', '2', '3', '4']:
                        cur_t = int(p_s_tones[j][-1])
                    else:
                        cur_t = 0
                    if cur_t in tone_dic[tone_s[j]]:
                        correct_tone_words += 1
                        cur_line_res.append(1)
                    else:
                        cur_line_res.append(0)
                cur_correct_res.append(cur_line_res)
            cur_reward = 1.0 * correct_tone_words / total_word_count
            if cur_reward >= tone_reward:
                chosen_tone = tone
                tone_reward = cur_reward
                correct_res = cur_correct_res
        return tone_reward, [chosen_tone, correct_res]


if __name__ == '__main__':
    content = {
                "标题": "长相思·螺丝",
                "诗歌": [
                    "彩袖殷勤捧玉钟"
            ]
        }
    tags= {
            "体裁": "宋词",
            "词牌名": "采桑子",
            "主题": "奇异博士",
            "包含": "石榴",
            "风格": "李清照"
        }
    reward_dic, ticai_res = poem_reward(content, tags)

    print("reward:", reward_dic)
    print("ticai res:", ticai_res)
    print("final", reward_combine(reward_dic))