import sys
import json

from pathlib import Path

def calculate_metrics(score_item, clinical_department, guide3_dataset, guide5_dataset):
    """Compute guide3/guide5 departmental accuracy for one group of cases.

    A case counts as correct when its ground-truth ``clinical_department``
    equals at least one entry of its ``predicted_clinical_department`` list.

    Args:
        score_item: Dict collecting the scores of one model; mutated in place.
        clinical_department: Key under which the results are stored
            (a department name, or a label such as ``'overall'``).
        guide3_dataset: Cases used for the ``guide3`` metric. Each case must
            provide ``'clinical_department'`` and
            ``'predicted_clinical_department'`` (presumably the top-3
            candidate list; the length is not checked here).
        guide5_dataset: Cases used for the ``guide5`` metric, same layout.

    Returns:
        The same ``score_item`` object, where ``score_item[clinical_department]``
        has been replaced by a dict holding ``'guide3_departmental_accuracy'``
        and ``'guide5_departmental_accuracy'`` as percentages rounded to two
        decimals. An empty dataset yields ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    metrics = {}
    score_item[clinical_department] = metrics

    # ====================================================
    # Score the department-guidance task
    # ====================================================
    for k, guide_dataset in ((3, guide3_dataset), (5, guide5_dataset)):
        total_cases = len(guide_dataset)
        accurate_cases = 0
        for item in guide_dataset:
            predicted_departments = item['predicted_clinical_department']
            ground_truth_department = item['clinical_department']
            if any(predicted == ground_truth_department
                   for predicted in predicted_departments):
                accurate_cases += 1

        # A department may have no cases at all (per-department filtering in
        # the caller can produce an empty list); report 0.0 rather than crash.
        accuracy = (accurate_cases / total_cases) * 100 if total_cases else 0.0

        # Guidance-task metric
        metrics[f'guide{k}_departmental_accuracy'] = round(accuracy, 2)

    return score_item


def main():
    """Score every configured model and dump the combined result as JSON.

    Relies on module-level configuration: ``model_name_list``, ``language``,
    the model/institution/URL mapping dicts, the zh/en department lists,
    ``guide_dir``, ``score_dir`` and ``score_save_name``.

    For each model the ``acc3`` and ``acc5`` prediction files are loaded from
    ``guide_dir``, metrics are computed per clinical department and for the
    whole dataset (key ``'overall'``), and the resulting entry is appended to
    the ``'data'`` list of the output document written under ``score_dir``.
    """
    score_dict = {'code': 0, 'data': []}

    for model_name in model_name_list:
        print(model_name)

        # Load the guide3 file first, then the guide5 file.
        loaded = {}
        for k in (3, 5):
            load_path = guide_dir / Path(f'guide_{language}_{model_name}_acc{k}.json')
            with open(load_path, mode='r', encoding='utf-8') as fh:
                loaded[k] = json.load(fh)
        guide3_dataset, guide5_dataset = loaded[3], loaded[5]

        score_item = {
            'model': model_name_mapping_dict[model_name],
            'institution': institution_name_mapping_dict[model_name],
            'url': institution_url_mapping_dict[model_name],
        }

        # Compare results department by department
        departments = (clinical_department_zh_list if language == 'zh'
                       else clinical_department_en_list)

        # Group the cases of both datasets by department before scoring.
        cases_by_department = [
            (
                department,
                [case for case in guide3_dataset
                 if case['clinical_department'] == department],
                [case for case in guide5_dataset
                 if case['clinical_department'] == department],
            )
            for department in departments
        ]

        for department, guide3_cases, guide5_cases in cases_by_department:
            print(f'{department}: {len(guide3_cases)}')
            score_item = calculate_metrics(score_item, department,
                                           guide3_cases, guide5_cases)

        score_item = calculate_metrics(score_item, 'overall',
                                       guide3_dataset, guide5_dataset)
        print(score_item)
        score_dict['data'].append(score_item)

    score_save_path = score_dir / Path(score_save_name)
    with open(str(score_save_path), mode='w', encoding='utf-8') as fh:
        json.dump(score_dict, fh, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    # Script configuration: everything below is read by main() as module globals.
    sys.setrecursionlimit(3000)

    language = 'zh'
    score_save_name = 'score_accuracy(2024-05-27).json'

    # Models to score in this run; commented-out entries are currently disabled.
    model_name_list = [
        'baichuan2chat',
        # 'bianque2',
        'bluelmchat',
        'chatglm3',
        'claude3',
        # 'discmedllm',
        'geminipro',
        'gpt3.5',
        'gpt4',
        'huatuogpt2',
        'internlm2chat',
        'pulse',
        'qwenchat',
        'spark3',
        'taiyillm',
        'wingpt2',
        'yichat',
    ]

    # One row per model key: (display name, institution, reference URL).
    model_registry = {
        'baichuan2chat': ('Baichuan2-13B-Chat', 'Baichuan AI',
                          'https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat'),
        'bianque2': ('BianQue-2', 'SCUT-FT',
                     'https://huggingface.co/scutcyr/BianQue-2'),
        'bluelmchat': ('BlueLM-7B-Chat', 'Vivo',
                       'https://huggingface.co/vivo-ai/BlueLM-7B-Chat'),
        'chatglm3': ('ChatGLM3-6B', 'THUDM & Zhipu AI',
                     'https://huggingface.co/THUDM/chatglm3-6b'),
        'claude3': ('Claude-3', 'Anthropic',
                    'https://www.anthropic.com/news/claude-3-haiku'),
        'discmedllm': ('DISC-MedLLM', 'Fudan-DISC',
                       'https://huggingface.co/Flmc/DISC-MedLLM'),
        'geminipro': ('Gemini-Pro', 'Google',
                      'https://ai.google.dev/models/gemini'),
        'gpt3.5': ('GPT-3.5', 'OpenAI',
                   'https://platform.openai.com/docs/models/gpt-3-5'),
        'gpt4': ('GPT-4', 'OpenAI',
                 'https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo'),
        'huatuogpt2': ('HuatuoGPT2-34B', 'CUHK-Shenzhen',
                       'https://huggingface.co/FreedomIntelligence/HuatuoGPT2-34B'),
        'internlm2chat': ('InternLM2-20B-Chat', 'Shanghai AI Laboratory',
                          'https://huggingface.co/internlm/internlm2-chat-20b'),
        'pulse': ('PULSE-20B', 'Shanghai AI Laboratory',
                  'https://huggingface.co/OpenMEDLab/PULSE-20bv5'),
        'qwenchat': ('Qwen-72B-Chat', 'Alibaba Cloud',
                     'https://huggingface.co/Qwen/Qwen-72B-Chat'),
        'spark3': ('Spark-3', 'iFLYTEK',
                   'https://xinghuo.xfyun.cn/'),
        'taiyillm': ('Taiyi-LLM', 'DUTIR-BioNLP',
                     'https://huggingface.co/DUTIR-BioNLP/Taiyi-LLM'),
        'wingpt2': ('WiNGPT2-14B-Chat', 'Winning Health',
                    'https://huggingface.co/winninghealth/WiNGPT2-14B-Chat'),
        'yichat': ('Yi-34B-Chat', '01 AI',
                   'https://huggingface.co/01-ai/Yi-34B-Chat'),
    }

    # Per-field lookup tables consumed by main(), derived from the registry.
    model_name_mapping_dict = {
        key: row[0] for key, row in model_registry.items()
    }
    institution_name_mapping_dict = {
        key: row[1] for key, row in model_registry.items()
    }
    institution_url_mapping_dict = {
        key: row[2] for key, row in model_registry.items()
    }

    # Chinese department label -> English department label.
    clinical_department_zh_to_en_dict = dict([
        ('乳腺外科', 'breast surgical department'),
        ('产科', 'obstetrics department'),
        ('儿科', 'pediatrics department'),
        ('内分泌内科', 'endocrinology department'),
        ('呼吸内科', 'respiratory medicine department'),
        ('妇科', 'gynecology department'),
        ('心脏外科', 'cardiac surgical department'),
        ('心血管内科', 'cardiovascular medicine department'),
        ('泌尿外科', 'urinary surgical department'),
        ('消化内科', 'gastroenterology department'),
        ('甲状腺外科', 'thyroid surgical department'),
        ('疝外科', 'hernia surgical department'),
        ('神经内科', 'neurology department'),
        ('神经外科', 'neurosurgery department'),
        ('耳鼻咽喉头颈外科', 'otolaryngology head and neck surgical department'),
        ('肛门结直肠外科', 'anus and intestine surgical department'),
        ('肝胆胰外科', 'hepatobiliary and pancreas surgical department'),
        ('肾内科', 'nephrology department'),
        ('胃肠外科', 'gastrointestinal surgical department'),
        ('胸外科', 'thoracic surgical department'),
        ('血液内科', 'hematology department'),
        ('血管外科', 'vascular surgical department'),
        ('骨科', 'orthopedics department'),
    ])
    clinical_department_zh_list = list(clinical_department_zh_to_en_dict)
    clinical_department_en_list = list(clinical_department_zh_to_en_dict.values())

    # Input and output directories sit next to this script's parent directory;
    # both are created when missing (mkdir is a no-op for an existing directory).
    project_root = Path(__file__).parent.parent
    guide_dir = project_root / Path('guides')
    score_dir = project_root / Path('scores')
    for directory in (guide_dir, score_dir):
        directory.mkdir(parents=True, exist_ok=True)

    main()
