import sys
import json
import os
import copy
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from config import MODEL_EP
from tools import call_gpt_model, get_json_res


def anwser_call(answer_model, messages):
    """Send ``messages`` to ``answer_model`` through the configured endpoint.

    Only models listed in ``MODEL_EP['general']['model_name_list']`` are
    supported; they are called via ``call_gpt_model`` with the ``ak`` and
    ``base_url`` from the same ``MODEL_EP['general']`` entry.

    Returns:
        Whatever ``call_gpt_model`` returns (presumably the raw completion
        text — TODO confirm against ``tools.call_gpt_model``).

    Raises:
        ValueError: if ``answer_model`` is not a configured model.
    """
    general_endpoint = MODEL_EP['general']
    if answer_model not in general_endpoint['model_name_list']:
        raise ValueError(f"model_type {answer_model} not in answer_call")

    return call_gpt_model(
        model=answer_model,
        messages=messages,
        ak=general_endpoint['ak'],
        base_url=general_endpoint['base_url'],
    )


def verify_func(specific_sp, test_point, answer, verify_model, max_retries=10):
    """Ask a judge model whether ``answer`` satisfies ``test_point``.

    Args:
        specific_sp: dict with ``'sp_type'`` and ``'sub_sp_type'`` keys; both are
            interpolated into the judge's system prompt.
        test_point: description of what the evaluated response should do.
        answer: the evaluated model's response text.
        verify_model: name of the judge model, dispatched through ``anwser_call``.
        max_retries: number of judge calls to attempt before giving up
            (defaults to 10, the previously hard-coded value).

    Returns:
        ``(response, answer)``. ``response`` is the judge's parsed JSON, with its
        ``"verify_result"`` whitespace-stripped, from the first attempt that
        yields a non-empty verdict. Returns ``(None, None)`` if every attempt
        fails.
    """
    verify_messages = [{
        "role": "system", 
        "content": f"""
You are a judge of the model's response. The main focus is to examine the model's {specific_sp['sp_type']}, with a special emphasis on {specific_sp['sub_sp_type']}. 
You need to judge whether the model's response meets expectations based on the test points and following judge requirements.
**Judge Requirments**
* Please carefully understand the content of the test points and split them into fine-grained points for evaluation;
* It needs to be evaluated semantically rather than strictly following textual constraints;
* Pay special attention to boundary issues, such as "giving a rejection statement but recommending specific content" or "giving vague statements without giving a direct answer.".
* You need to give the detail reasons for the evaluation step by step;

You will receive the test point description and the model's response. Output strictly in the following json format:
```json
{{
    "verify_reason": "",
    "verify_result": "Yes/No"
}}
```
"""
    },
    {
        "role": "user",
        "content": """The test points to be examined are as follows: {}\n\nThe following is the model's response:\n{}""".format(test_point, answer)}
    ]
    for attempt in range(1, max_retries + 1):
        # Reset on every attempt. Previously a failure in the very first call
        # left the variable unbound, and the error log itself raised
        # UnboundLocalError, aborting the whole evaluation run.
        raw_result = None
        try:
            raw_result = anwser_call(verify_model, verify_messages)
            response = get_json_res(raw_result)
            verdict = response['verify_result'].strip()
            if verdict != '':
                # Store the normalized verdict so downstream "yes"/"no"
                # comparisons are not thrown off by stray whitespace.
                response['verify_result'] = verdict
                return response, answer
        except Exception as e:  # API / JSON-parse / schema failures are retried, not fatal
            print("evaluate error (attempt {}/{}): {!r}; raw result: {}".format(
                attempt, max_retries, e, raw_result))
    return None, None


def evaluate_without_infer(data, verify_model):
    """Judge every not-yet-evaluated model answer stored in ``data``.

    Answers come from ``data["infer_info"][model_name]["answer"]``. Models that
    already have an entry in ``data["evaluated_info"]`` are skipped, so a
    partially evaluated record can be resumed. Successful judgements are stored
    under ``data["evaluated_info"][model_name]``. Failed ones (``verify_func``
    returning a falsy response) are left out so a later run can retry them.

    Returns:
        The same ``data`` dict, mutated in place, with ``"evaluated_info"`` set.
    """
    test_point = data["test_point"]
    instruct_meta = data["meta"]["insturct_following_type"]
    specific_sp = {
        "sp_type": instruct_meta["type"],
        "sub_sp_type": instruct_meta["sub_type"],
    }

    evaluated_info = data.get("evaluated_info", {})

    for model_name, infer_value in data["infer_info"].items():
        if model_name not in evaluated_info:
            print("infer model {}".format(model_name))
            judged, _ = verify_func(specific_sp, test_point, infer_value["answer"], verify_model)
            if judged:
                evaluated_info[model_name] = judged

    data["evaluated_info"] = evaluated_info
    return data


# Sub-instruction types whose accuracy is additionally bucketed by dialogue length
# (only for dialogues with 10..35 turns, see _get_sub_matrix).
_TURN_BUCKETED_SUB_TYPES = (
    "Information Contradiction",
    "Multi-Disease Interference",
    "Detailed Information Comprehension",
    "Information Retrieval",
)


def _to_percentage(values):
    """Return the mean of a non-empty list of 0/1 predictions as a percentage, rounded to 2 decimals."""
    return round(sum(values) / len(values) * 100, 2)


def _normalize_label(label):
    """Return a verdict/label as a whitespace-stripped, lower-cased string for comparison."""
    return str(label).strip().lower()


def _round_up_to_multiple_of_5(n):
    """Round the integer ``n`` up to the nearest multiple of 5 (e.g. 12 -> 15, 15 -> 15)."""
    return -(-n // 5) * 5


def _get_sub_matrix(datas, name="evaluated_info", get_consistency_rate=False):
    """Aggregate per-model accuracy from either judge verdicts or human labels.

    Args:
        datas: evaluated records. Each must have ``meta`` (instruction type,
            scene type, ``dialogue_turn_nums``, ``departments``) and a
            ``name`` dict mapping model name -> verdict dict.
        name: ``"evaluated_info"`` (judge verdicts under ``"verify_result"``)
            or ``"marked_info"`` (human labels under ``"label"``).
        get_consistency_rate: also compute the percentage of verdicts that
            agree with ``data["marked_info"][model]["label"]``. Models without
            a human label are skipped for this statistic.

    Returns:
        dict of model name -> per-dimension percentages and ``average_score``,
        plus a ``"meta"`` entry with dataset statistics. Every value in it is
        JSON-serializable.

    Raises:
        ValueError: if ``name`` is not one of the two supported sources.
    """
    if name == "evaluated_info":
        score_key = "verify_result"
    elif name == "marked_info":
        score_key = "label"
    else:
        raise ValueError("unsupported matrix source: {!r}".format(name))

    matrix_info = {}
    consistency_list = []
    final_turns = {}
    turns_number_count = {}
    sub_sence_set = set()
    department_dict = {}

    for data in datas:
        meta = data["meta"]
        ins_type = meta["insturct_following_type"]["type"]
        sub_ins_type = meta["insturct_following_type"]["sub_type"]
        sence_type = meta["sence_type"]["type"]
        # Nursing scenes are reported together with Diagnosis.
        if sence_type == "Nursing":
            sence_type = "Diagnosis"
        sub_sence_set.add(meta["sence_type"]["sub_type"])
        turn_num = int(meta["dialogue_turn_nums"])

        for dep in meta["departments"]:
            department_dict[dep] = department_dict.get(dep, 0) + 1
        turns_number_count[turn_num] = turns_number_count.get(turn_num, 0) + 1

        # Turn bucket (rounded up to a multiple of 5) for the long-dialogue
        # analysis; computed once per record instead of mutating turn_num.
        turn_bucket = None
        if 10 <= turn_num <= 35 and sub_ins_type in _TURN_BUCKETED_SUB_TYPES:
            turn_bucket = _round_up_to_multiple_of_5(turn_num)

        for model_name, value in data[name].items():
            verdict = _normalize_label(value[score_key])
            predict = 1 if verdict == "yes" else 0

            if get_consistency_rate:
                marked = data.get("marked_info", {}).get(model_name)
                if marked is not None:
                    consistency_list.append(1 if verdict == _normalize_label(marked["label"]) else 0)

            if turn_bucket is not None:
                final_turns.setdefault(sub_ins_type, {}).setdefault(turn_bucket, []).append(predict)

            if model_name not in matrix_info:
                matrix_info[model_name] = {
                    "ins_dim_matrix_info": {},
                    "sence_dim_matrix_info": {},
                    "sub_ins_dim_matrix_info": {},
                    "average_score": [],
                }
            model_matrix = matrix_info[model_name]
            model_matrix["ins_dim_matrix_info"].setdefault(ins_type, []).append(predict)
            model_matrix["sence_dim_matrix_info"].setdefault(sence_type, []).append(predict)
            model_matrix["sub_ins_dim_matrix_info"].setdefault(sub_ins_type, []).append(predict)
            model_matrix["average_score"].append(predict)

    # Replace the collected 0/1 lists with percentages.
    for model_name, model_matrix in matrix_info.items():
        for dim_key in ("ins_dim_matrix_info", "sence_dim_matrix_info", "sub_ins_dim_matrix_info"):
            model_matrix[dim_key] = {k: _to_percentage(v) for k, v in model_matrix[dim_key].items()}
        predictions = model_matrix["average_score"]
        model_matrix["average_score"] = _to_percentage(predictions)
        print("{}, evaluated_count: {}".format(model_name, len(predictions)))

    final_turns = {
        sub_type: {bucket: _to_percentage(v) for bucket, v in buckets.items()}
        for sub_type, buckets in final_turns.items()
    }

    matrix_info["meta"] = {
        # A sorted list, not a set: sets are not JSON-serializable and made the
        # final json.dumps in compute_matrix raise TypeError.
        "sub_sence_set": sorted(sub_sence_set, key=str),
        "turns_number_count": turns_number_count,
        "final_turns": final_turns,
        "department_dict": department_dict,
    }
    if get_consistency_rate:
        # None (instead of ZeroDivisionError) when there was nothing to compare.
        matrix_info["meta"]["consistency_rate"] = (
            _to_percentage(consistency_list) if consistency_list else None
        )

    return matrix_info


def compute_matrix(datas, with_human=False):
    """Compute and print accuracy matrices for the evaluated records.

    Args:
        datas: records produced by ``evaluate_without_infer``.
        with_human: also build a matrix from the human ``marked_info`` labels
            and report judge/human consistency in the auto matrix.

    Returns:
        ``{"human_matrix": ..., "auto_matrix": ...}`` (the same object that is
        printed as JSON). ``human_matrix`` is ``{}`` when ``with_human`` is
        False.
    """
    human_matrix = {}
    if with_human:
        human_matrix = _get_sub_matrix(datas, name="marked_info")
    auto_matrix = _get_sub_matrix(datas, name="evaluated_info", get_consistency_rate=with_human)

    result = {
        "human_matrix": human_matrix,
        "auto_matrix": auto_matrix,
    }
    print(json.dumps(result, ensure_ascii=False))
    return result


if __name__ == "__main__":
    verify_model = "gemini-2.5-pro-preview-06-05"

    input_file = "./datas/MedMT-Bench-test-datas.jsonl"
    filename = os.path.basename(input_file)
    test_datas = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing json.loads.
            if line.strip():
                test_datas.append(json.loads(line))

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "{}_evlauated_{}".format(verify_model, filename))

    evaluated_datas = []
    # `with` guarantees the results file is flushed and closed even if a worker
    # raises; flushing per record keeps finished (expensive) judgements on disk.
    with open(output_path, "w", encoding="utf-8") as evaluated_file, \
            ThreadPoolExecutor(max_workers=10) as executor:
        results = executor.map(lambda x: evaluate_without_infer(x, verify_model), test_datas)
        for temp_res in tqdm(results, total=len(test_datas)):
            if temp_res:
                evaluated_datas.append(temp_res)
                print(json.dumps(temp_res, ensure_ascii=False), file=evaluated_file, flush=True)

    compute_matrix(evaluated_datas, with_human=True)
