import sys
import json
import numpy as np

from tqdm import tqdm
from pathlib import Path

from scorers.automatic import AutomaticScorer


def _department_matches(case, department_names):
    """Return True when the predicted department names exactly one known department
    and that prediction overlaps (substring in either direction) the ground truth."""
    predicted = case['predicted_clinical_department']
    ground_truth = case['clinical_department']
    hit_count = sum(1 for name in department_names if name in predicted or predicted in name)
    if hit_count > 1:
        # A prediction that names two or more departments is counted as wrong;
        # log the case so it can be inspected.
        print(f"{hit_count}, {case['clinical_case_uid']}")
    return hit_count == 1 and (ground_truth in predicted or predicted in ground_truth)


def _diagnosis_matches(predicted_diagnosis, ground_truth_diagnoses):
    """Return True when any ground-truth diagnosis is contained in the prediction."""
    return any(diagnosis in predicted_diagnosis for diagnosis in ground_truth_diagnoses)


def _score_generated_text(scorer, language, predictions, references):
    """Run BLEU, ROUGE and BERTScore and return ``(metrics, rounded mean of the three)``."""
    metrics = {
        'bleu': scorer.calculate_bleu(language, predictions, references),
        'rouge': scorer.calculate_rouge(language, predictions, references),
        'bertscore': scorer.calculate_bertscore(language, predictions, references),
    }
    return metrics, round(float(np.mean(list(metrics.values()))), 2)


def calculate_metrics(score_item, clinical_department, dataset, gd_dataset, language):
    """Compute all scores for one group of cases and store them in ``score_item``.

    The entry written to ``score_item[clinical_department]`` contains:

    * ``metrics``: raw BLEU / ROUGE / BERTScore values per diagnosis task;
    * ``comprehensive_diagnostic_accuracy``: percentage of cases whose predicted
      department AND predicted principal diagnosis both pass;
    * one averaged score per task in ``clinical_diagnosis_part_list`` plus
      ``imaging_diagnosis``;
    * ``diagnosis_average``: mean of those task scores;
    * ``acceptability``: accuracy * diagnosis_average / 100.

    Args:
        score_item: Result dict for one model; mutated in place and returned.
        clinical_department: Key under which the results are stored (a department
            name or ``'overall'``).
        dataset: List of inference case dicts to score.
        gd_dataset: Ground-truth records, read by position: ``gd_dataset[i]`` must
            describe ``dataset[i]``.
        language: Language code forwarded to the scorer; ``'zh'`` selects the
            Chinese department names for matching, anything else the English ones.

    Returns:
        The same ``score_item`` dict.

    Note:
        An empty ``dataset`` yields an entry whose scores are all ``0.0`` (the
        scorer is not invoked) instead of failing with ``ZeroDivisionError``.
    """
    diagnosis_task_list = clinical_diagnosis_part_list + ['imaging_diagnosis']
    entry = {'metrics': {}}
    score_item[clinical_department] = entry

    if not dataset:
        # Nothing to score: record zeros so downstream consumers still find every key.
        for diagnosis_task in diagnosis_task_list:
            entry['metrics'][diagnosis_task] = {'bleu': 0.0, 'rouge': 0.0, 'bertscore': 0.0}
        entry['comprehensive_diagnostic_accuracy'] = 0.0
        for diagnosis_task in diagnosis_task_list:
            entry[diagnosis_task] = 0.0
        entry['diagnosis_average'] = 0.0
        entry['acceptability'] = 0.0
        print(score_item)
        return score_item

    scorer = AutomaticScorer()
    # Match predictions against the department names of the evaluated language.
    department_names = clinical_department_zh_list if language == 'zh' else clinical_department_en_list

    # ====================================================
    # Pass rate (comprehensive diagnostic accuracy)
    # ====================================================
    passed = 0
    for index, case in enumerate(tqdm(dataset)):
        department_ok = _department_matches(case, department_names)
        diagnosis_ok = _diagnosis_matches(
            case['predicted_principal_diagnosis'],
            gd_dataset[index]['ground_truth_principal_diagnosis'],
        )
        # A case passes only when both checks succeed.
        if department_ok and diagnosis_ok:
            passed += 1
    entry['comprehensive_diagnostic_accuracy'] = round((passed / len(dataset)) * 100, 2)

    # ====================================================
    # Clinical diagnosis task scores
    # ====================================================
    for clinical_diagnosis_part in clinical_diagnosis_part_list:
        predictions = []
        references = []
        for case in tqdm(dataset):
            predictions.append(case[f'predicted_{clinical_diagnosis_part}'])
            references.append(case[clinical_diagnosis_part])
        # Raw metrics go under 'metrics'; their mean is the sub-task score.
        entry['metrics'][clinical_diagnosis_part], entry[clinical_diagnosis_part] = _score_generated_text(
            scorer, language, predictions, references)

    # ====================================================
    # Imaging diagnosis task score
    # ====================================================
    predictions = []
    references = []
    for case in tqdm(dataset):
        examinations = case['imageological_examination']
        # Only dict-valued examinations carry per-examination impressions.
        if isinstance(examinations, dict):
            for examination in examinations.values():
                predictions.append(examination['predicted_impression'])
                references.append(examination['impression'])
    entry['metrics']['imaging_diagnosis'], entry['imaging_diagnosis'] = _score_generated_text(
        scorer, language, predictions, references)

    # ====================================================
    # Average over all diagnosis tasks
    # ====================================================
    entry['diagnosis_average'] = round(
        float(np.mean([entry[diagnosis_task] for diagnosis_task in diagnosis_task_list])), 2)

    # ====================================================
    # Acceptability
    # ====================================================
    entry['acceptability'] = round(
        float((entry['comprehensive_diagnostic_accuracy'] * entry['diagnosis_average']) / 100), 2)

    print(score_item)

    return score_item


def main():
    """Score every model in ``model_name_list`` and write the ranked results as JSON.

    For each model the inference file ``inference_{language}_{model_name}.json`` is
    loaded from ``inference_dir``, scored per clinical department and overall via
    ``calculate_metrics``, and the collected entries are sorted by overall
    acceptability (descending) before being saved to ``score_dir / score_save_name``.

    All configuration is read from module-level names set in the ``__main__`` block.
    """
    gd_zh_load_path = data_dir / Path(gd_zh_load_name)
    with open(gd_zh_load_path, mode='r', encoding='utf-8') as file:
        gd_zh_dataset = json.load(file)
    # No English ground-truth file is loaded (gd_en_load_name is empty).
    gd_en_dataset = None

    # Department names and ground truth depend only on the language, not on the model.
    if language == 'zh':
        clinical_department_list = clinical_department_zh_list
        gd_dataset = gd_zh_dataset
    else:
        clinical_department_list = clinical_department_en_list
        gd_dataset = gd_en_dataset

    score_dict = {'code': 0, 'data': []}
    for model_name in model_name_list:
        print(model_name)
        inference_load_name = f'inference_{language}_{model_name}.json'
        inference_load_path = inference_dir / Path(inference_load_name)
        with open(inference_load_path, mode='r', encoding='utf-8') as file:
            dataset = json.load(file)

        score_item = {
            'model': model_name_mapping_dict[model_name],
            'institution': institution_name_mapping_dict[model_name],
            'url': institution_url_mapping_dict[model_name],
        }

        # Compare results department by department.
        for clinical_department in clinical_department_list:
            case_indices = [
                index for index, item in enumerate(dataset)
                if item['clinical_department'] == clinical_department
            ]
            clinical_cases = [dataset[index] for index in case_indices]
            # calculate_metrics reads ground truth by position, so the ground truth
            # must be narrowed with the same indices as the cases; passing the full
            # list would pair each case with an unrelated record.
            if gd_dataset is None:
                gd_cases = None
            else:
                gd_cases = [gd_dataset[index] for index in case_indices]
            print(f'{clinical_department}: {len(clinical_cases)}')
            score_item = calculate_metrics(score_item, clinical_department, clinical_cases, gd_cases, language)
        score_item = calculate_metrics(score_item, 'overall', dataset, gd_dataset, language)

        score_dict['data'].append(score_item)

    # Rank models by overall acceptability, best first.
    score_dict['data'].sort(key=lambda x: x['overall']['acceptability'], reverse=True)
    print(score_dict)

    score_save_path = score_dir / Path(score_save_name)
    with open(str(score_save_path), mode='w', encoding='utf-8') as file:
        json.dump(score_dict, file, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    sys.setrecursionlimit(3000)

    # ------------------------------------------------------------------
    # Run configuration
    # ------------------------------------------------------------------
    language = 'zh'
    gd_zh_load_name = 'disease_diagnosis_ground_truth.json'
    gd_en_load_name = ''
    score_save_name = 'score_acceptability(2024-05-27).json'

    # Diagnosis sub-tasks scored for every clinical case.
    clinical_diagnosis_part_list = [
        'diagnostic_basis',
        'differential_diagnosis',
        'therapeutic_principle',
        'treatment_plan',
    ]

    # ------------------------------------------------------------------
    # Evaluated models: key -> (display name, institution, url)
    # ------------------------------------------------------------------
    model_registry = {
        'baichuan2chat': ('Baichuan2-13B-Chat', 'Baichuan AI',
                          'https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat'),
        'bianque2': ('BianQue-2', 'SCUT-FT',
                     'https://huggingface.co/scutcyr/BianQue-2'),
        'bluelmchat': ('BlueLM-7B-Chat', 'Vivo',
                       'https://huggingface.co/vivo-ai/BlueLM-7B-Chat'),
        'chatglm3': ('ChatGLM3-6B', 'THUDM & Zhipu AI',
                     'https://huggingface.co/THUDM/chatglm3-6b'),
        'claude3': ('Claude-3', 'Anthropic',
                    'https://www.anthropic.com/news/claude-3-haiku'),
        'discmedllm': ('DISC-MedLLM', 'Fudan-DISC',
                       'https://huggingface.co/Flmc/DISC-MedLLM'),
        'geminipro': ('Gemini-Pro', 'Google',
                      'https://ai.google.dev/models/gemini'),
        'gpt3.5': ('GPT-3.5', 'OpenAI',
                   'https://platform.openai.com/docs/models/gpt-3-5'),
        'gpt4': ('GPT-4', 'OpenAI',
                 'https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo'),
        'huatuogpt2': ('HuatuoGPT2-34B', 'CUHK-Shenzhen',
                       'https://huggingface.co/FreedomIntelligence/HuatuoGPT2-34B'),
        'internlm2chat': ('InternLM2-20B-Chat', 'Shanghai AI Laboratory',
                          'https://huggingface.co/internlm/internlm2-chat-20b'),
        'pulse': ('PULSE-20B', 'Shanghai AI Laboratory',
                  'https://huggingface.co/OpenMEDLab/PULSE-20bv5'),
        'qwenchat': ('Qwen-72B-Chat', 'Alibaba Cloud',
                     'https://huggingface.co/Qwen/Qwen-72B-Chat'),
        'spark3': ('Spark-3', 'iFLYTEK',
                   'https://xinghuo.xfyun.cn/'),
        'taiyillm': ('Taiyi-LLM', 'DUTIR-BioNLP',
                     'https://huggingface.co/DUTIR-BioNLP/Taiyi-LLM'),
        'wingpt2': ('WiNGPT2-14B-Chat', 'Winning Health',
                    'https://huggingface.co/winninghealth/WiNGPT2-14B-Chat'),
        'yichat': ('Yi-34B-Chat', '01 AI',
                   'https://huggingface.co/01-ai/Yi-34B-Chat'),
    }

    # Per-field views of the registry, under the names main() reads.
    model_name_list = list(model_registry)
    model_name_mapping_dict = {
        key: display_name for key, (display_name, _, _) in model_registry.items()
    }
    institution_name_mapping_dict = {
        key: institution for key, (_, institution, _) in model_registry.items()
    }
    institution_url_mapping_dict = {
        key: url for key, (_, _, url) in model_registry.items()
    }

    # ------------------------------------------------------------------
    # Clinical departments: Chinese name -> English name
    # ------------------------------------------------------------------
    clinical_department_zh_to_en_dict = {
        '乳腺外科': 'breast surgical department',
        '产科': 'obstetrics department',
        '儿科': 'pediatrics department',
        '内分泌内科': 'endocrinology department',
        '呼吸内科': 'respiratory medicine department',
        '妇科': 'gynecology department',
        '心脏外科': 'cardiac surgical department',
        '心血管内科': 'cardiovascular medicine department',
        '泌尿外科': 'urinary surgical department',
        '消化内科': 'gastroenterology department',
        '甲状腺外科': 'thyroid surgical department',
        '疝外科': 'hernia surgical department',
        '神经内科': 'neurology department',
        '神经外科': 'neurosurgery department',
        '耳鼻咽喉头颈外科': 'otolaryngology head and neck surgical department',
        '肛门结直肠外科': 'anus and intestine surgical department',
        '肝胆胰外科': 'hepatobiliary and pancreas surgical department',
        '肾内科': 'nephrology department',
        '胃肠外科': 'gastrointestinal surgical department',
        '胸外科': 'thoracic surgical department',
        '血液内科': 'hematology department',
        '血管外科': 'vascular surgical department',
        '骨科': 'orthopedics department',
    }
    clinical_department_zh_list = list(clinical_department_zh_to_en_dict)
    clinical_department_en_list = [
        english_name for english_name in clinical_department_zh_to_en_dict.values()
    ]

    # ------------------------------------------------------------------
    # Working directories, resolved relative to this file's parent package
    # ------------------------------------------------------------------
    project_root = Path(__file__).parent.parent
    inference_dir = project_root / Path('inferences')
    data_dir = project_root / Path('data')
    score_dir = project_root / Path('scores')
    # exist_ok=True makes mkdir a no-op for directories that already exist.
    for required_dir in (inference_dir, data_dir, score_dir):
        required_dir.mkdir(parents=True, exist_ok=True)

    main()
