# benchmark for fake player
import sys
import os
import json
from concurrent.futures import ThreadPoolExecutor
import threading
from sentence_transformers import SentenceTransformer
from utils.util import get_prompt
from utils.api import request_api, request_api_notopp
import pandas as pd
from json_repair import repair_json
import copy
import numpy as np
from tqdm import tqdm
write_lock = threading.Lock()


def safe_entropy(prob_dist):
    """Compute the base-2 Shannon entropy of a possibly unnormalized vector.

    Negative entries are clamped to zero, the vector is rescaled so that it
    sums to (approximately) one, and only entries above ``1e-12`` contribute,
    so ``log2(0)`` is never evaluated.

    Args:
        prob_dist: Array-like of weights.

    Returns:
        The entropy in bits, or ``0.0`` when no entry exceeds the threshold.
    """
    # Clamp negatives, then rescale; the epsilon guards against a zero sum.
    clamped = np.maximum(prob_dist, 0)
    normalized = clamped / (clamped.sum() + 1e-12)

    # Only probabilities that are meaningfully above zero take part.
    support = normalized[normalized > 1e-12]
    if support.size == 0:
        return 0.0

    return -np.sum(support * np.log2(support))

def judge(message, **kwargs):
    """Ask the judge model to grade ``message`` and return the requested fields.

    The message is sent through ``request_api_notopp`` with ``temperature=0``.
    The reply's ``content`` is stripped of markdown code fences and newlines,
    then parsed as JSON, falling back to ``repair_json`` when strict parsing
    fails.

    Args:
        message: Chat messages forwarded unchanged to ``request_api_notopp``.
        **kwargs: Only the keyword *names* are used; each one names a field to
            extract from the parsed reply. The values are ignored.

    Returns:
        A dict with one entry per requested field. Fields that are absent
        from the reply map to ``None``; if the reply is not a JSON object,
        every field maps to ``None``.

    Raises:
        json.JSONDecodeError: If the repaired text still is not valid JSON.
    """
    response = request_api_notopp(message, temperature=0)
    content = response.content

    # Drop markdown fences. The "json" language tag is removed only when it is
    # attached to a fence: stripping every "json" substring would also corrupt
    # field values (e.g. an explanation that mentions the word).
    content = content.replace('```json', '').replace('```', '').replace('\n', '')

    try:
        json_data = json.loads(content)
    except json.JSONDecodeError:
        # Strict parsing failed; let json_repair try to salvage the reply.
        json_data = json.loads(repair_json(content))

    # A bare string/list/number reply has no fields to extract. Without this
    # guard `key in json_data` would be a substring test on strings.
    if not isinstance(json_data, dict):
        json_data = {}

    return {key: json_data.get(key) for key in kwargs}

def process_judge(df, index, message, fw, **kwargs):
    """Judge one row and append the merged record to a shared JSONL file.

    Args:
        df: DataFrame holding the rows under evaluation.
        index: Positional index (``iloc``) of the row to judge.
        message: Chat messages passed to ``judge``.
        fw: Writable text file object shared between worker threads.
        **kwargs: Field names forwarded to ``judge``.
    """
    # Query the judge first so a failing request never touches the file.
    verdict = judge(message, **kwargs)
    record = {**df.iloc[index].to_dict(), **verdict}
    line = json.dumps(record, ensure_ascii=False) + '\n'

    # The file handle is shared by all workers; serialize writes.
    with write_lock:
        fw.write(line)
        fw.flush()



def eval_topic(df, save_dir, eval_prompt_file='/data/workspace/fake-player-pro/benckmark/prompts/topic_eval.txt'):
    """Grade topic width/depth for every dialogue and return the mean total score.

    For each row a prompt is built from the template in ``eval_prompt_file``
    (placeholders ``{{domain}}``, ``{{topic}}``, ``{{context}}``) and sent to
    the judge in a thread pool. Raw judge output is written to
    ``<save_dir>/topic_eval.jsonl`` (in completion order) and the scored rows
    to ``<save_dir>/topic_score.jsonl``.

    Args:
        df: DataFrame with ``domain``, ``topic`` and ``context`` columns.
        save_dir: Directory receiving the two JSONL files.
        eval_prompt_file: Path of the prompt template.

    Returns:
        Mean of ``topic_total_score`` over all judged rows. Rows whose
        dialogue type or grade is not recognised score NaN and are skipped
        by the mean.
    """
    save_eval = os.path.join(save_dir, 'topic_eval.jsonl')
    meta_prompt = get_prompt(eval_prompt_file)

    messages_list = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        prompt = (
            meta_prompt
            .replace('{{domain}}', str(row['domain']))
            .replace('{{topic}}', str(row['topic']))
            .replace('{{context}}', str(row['context']))
        )
        messages_list.append([{"role": "system", "content": prompt}])

    # Explicit UTF-8: records are dumped with ensure_ascii=False and pandas
    # reads the file back as UTF-8. `with` closes the file even on errors.
    with open(save_eval, 'w', encoding='utf-8') as fw:
        with ThreadPoolExecutor(max_workers=50) as executor:
            futures = [
                executor.submit(
                    process_judge, df, i, messages, fw,
                    player_topic_depth=True, player_topic_width=True, explanation=True,
                )
                for i, messages in enumerate(messages_list)
            ]
        # Leaving the executor block waits for all workers. Surface failures
        # instead of dropping them silently; the run stays best-effort.
        for i, future in enumerate(futures):
            error = future.exception()
            if error is not None:
                print(f'topic eval failed for row {i}: {error!r}')

    # Score awarded per dialogue type, axis and judge grade.
    score_table = {
        '闲聊类': {'width': {'A': 1, 'B': 0.5}, 'depth': {'A': 0.5, 'B': 1}},
        '知识类': {'width': {'A': 0.5, 'B': 1}, 'depth': {'A': 1, 'B': 0.5}},
    }

    def grade(domain_type, axis, choice):
        # Anything unrecognised (missing field, unexpected grade, non-string
        # value) scores NaN rather than raising.
        if not isinstance(domain_type, str) or not isinstance(choice, str):
            return np.nan
        return score_table.get(domain_type, {}).get(axis, {}).get(choice, np.nan)

    # post eval
    df_eval = pd.read_json(save_eval, lines=True)
    # Independent copy so adding score columns can never alias df_eval.
    df_score = df_eval.copy()
    print(len(df_score))

    width_scores = []
    depth_scores = []
    for _, row in df_eval.iterrows():
        domain = row['domain']
        domain_type = domain.get('对话类型') if isinstance(domain, dict) else None
        width_scores.append(grade(domain_type, 'width', row['player_topic_width']))
        depth_scores.append(grade(domain_type, 'depth', row['player_topic_depth']))

    df_score['topic_width_score'] = pd.Series(width_scores, index=df_eval.index, dtype=float)
    df_score['topic_depth_score'] = pd.Series(depth_scores, index=df_eval.index, dtype=float)
    df_score['topic_total_score'] = (
        df_score['topic_width_score'] + df_score['topic_depth_score']
    ) / 2

    mean_topic_width = df_score['topic_width_score'].mean()
    mean_topic_depth = df_score['topic_depth_score'].mean()

    print(f'mean_topic_width: {mean_topic_width}, mean_topic_depth: {mean_topic_depth}')
    save_score = os.path.join(save_dir, 'topic_score.jsonl')
    df_score.to_json(save_score, orient='records', lines=True, force_ascii=False)

    # Average total score over all judged rows.
    return df_score['topic_total_score'].mean()


def eval_human_like(df, save_dir, eval_prompt_file='/data/workspace/fake-player-pro/benckmark/prompts/human_like_eval2.txt'):
    """Grade how human-like every dialogue is and return the mean total score.

    For each row the ``{{context}}`` placeholder of the template in
    ``eval_prompt_file`` is filled in and the prompt is sent to the judge in a
    thread pool. Raw judge output goes to ``<save_dir>/human_like_eval.jsonl``
    (in completion order) and the scored rows to
    ``<save_dir>/human_like_score.jsonl``.

    Args:
        df: DataFrame with a ``context`` column.
        save_dir: Directory receiving the two JSONL files.
        eval_prompt_file: Path of the prompt template.

    Returns:
        Mean of ``human_like_total_score`` (``clean + free``) over all judged
        rows. Rows whose ``clean`` or ``free`` value is missing or not
        numeric score NaN and are skipped by the mean.
    """
    save_eval = os.path.join(save_dir, 'human_like_eval.jsonl')
    meta_prompt = get_prompt(eval_prompt_file)

    messages_list = [
        [{"role": "system", "content": meta_prompt.replace('{{context}}', str(context))}]
        for context in df['context']
    ]

    # Explicit UTF-8: records are dumped with ensure_ascii=False and pandas
    # reads the file back as UTF-8. `with` closes the file even on errors.
    with open(save_eval, 'w', encoding='utf-8') as fw:
        with ThreadPoolExecutor(max_workers=50) as executor:
            futures = [
                executor.submit(
                    process_judge, df, i, messages, fw,
                    clean=True, free=True, explanation=True,
                )
                for i, messages in enumerate(messages_list)
            ]
        # Leaving the executor block waits for all workers. Surface failures
        # instead of dropping them silently; the run stays best-effort.
        for i, future in enumerate(futures):
            error = future.exception()
            if error is not None:
                print(f'human-like eval failed for row {i}: {error!r}')

    def to_float(value):
        # The judge may omit a field (None) or return something non-numeric.
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    # post eval (read by path so no file handle is left open)
    df_eval = pd.read_json(save_eval, lines=True)
    df_score = df_eval.copy()

    clean = pd.Series([to_float(v) for v in df_eval['clean']], index=df_eval.index, dtype=float)
    free = pd.Series([to_float(v) for v in df_eval['free']], index=df_eval.index, dtype=float)
    df_score['human_like_total_score'] = clean + free

    mean_clean = clean.mean()
    mean_free = free.mean()
    print(f'mean_clean: {mean_clean}, mean_free: {mean_free}')
    save_score = os.path.join(save_dir, 'human_like_score.jsonl')
    df_score.to_json(save_score, orient='records', lines=True, force_ascii=False)

    # Average total score over all judged rows.
    return df_score['human_like_total_score'].mean()


def eval_diversity(df:pd.DataFrame, save_dir, embedding_model='/data/models/Qwen3-Embedding-0.6B', device="cuda"):
    """Score inter-session diversity of player utterances per scenario group.

    Rows are grouped by identical (persona, domain, topic). Within a group,
    for every turn index shared by all sessions, the ``player`` utterances are
    embedded and the largest pairwise similarity ``s`` is turned into
    ``10 * H([s, 1 - s])`` where ``H`` is the base-2 entropy. A group's score
    is the mean over those turns. Per-row results are written to
    ``<save_dir>/diversity_score.jsonl``.

    Args:
        df: DataFrame with ``persona``, ``domain``, ``topic`` and ``context``
            columns. It is not modified.
        save_dir: Directory receiving ``diversity_score.jsonl``.
        embedding_model: Name or path handed to ``SentenceTransformer``.
        device: Device handed to ``SentenceTransformer``.

    Returns:
        Mean ``inter_score`` over all rows.
    """
    model = SentenceTransformer(embedding_model, device=device)

    # Work on a copy so the caller's frame does not gain helper columns.
    df = df.copy()

    def to_standard_str(value):
        # Containers are checked first: pd.isna() on a list returns an array,
        # whose truth value is ambiguous.
        if isinstance(value, (dict, list)):
            # sort_keys makes equal content produce equal strings.
            return json.dumps(value, sort_keys=True, ensure_ascii=False)
        if pd.isna(value):
            return ""
        return str(value)

    def preprocess_group_fields(row):
        # Normalize the grouping fields to strings so they are hashable.
        try:
            return (
                to_standard_str(row['persona']),
                to_standard_str(row['domain']),
                to_standard_str(row['topic']),
            )
        except (KeyError, TypeError, ValueError) as e:
            print(f"预处理分组字段失败: {e}")
            # Must have exactly as many items as there are target columns.
            return ("", "", "")

    df[['persona_str', 'domain_str', 'topic_str']] = df.apply(
        preprocess_group_fields, axis=1, result_type='expand'
    )

    def inter_session_similarity(group_df):
        contexts = list(group_df["context"])
        # Only turns present in every session of the group can be compared.
        common_context_turn = min(len(context) for context in contexts)
        all_scores = []

        for turn in range(common_context_turn):
            sessions = [str(context[turn]['player']) for context in contexts]
            # Unit-length embeddings make the inner product a cosine similarity.
            embeddings = model.encode(sessions, batch_size=8, normalize_embeddings=True)
            similarity_matrix = np.inner(embeddings, embeddings)
            np.fill_diagonal(similarity_matrix, 0)  # ignore self-similarity

            # Similarity of the most similar pair of sessions.
            # NOTE(review): a group with a single session has no pair, so
            # max_sim is 0 and the turn scores 0 - confirm this is intended.
            max_sim = np.max(similarity_matrix)

            # Binary distribution [max_sim, 1 - max_sim].
            prob_dist = np.array([max_sim, 1 - max_sim])
            prob_dist /= prob_dist.sum()

            # A binary distribution has at most 1 bit of entropy, so the
            # score lies in [0, 10].
            all_scores.append(10 * safe_entropy(prob_dist))

        # Groups without any comparable turn get the full score.
        return np.mean(all_scores) if all_scores else 10.0

    results = []
    for _, group_df in df.groupby(['persona_str', 'domain_str', 'topic_str']):
        group_df = group_df.copy()
        # One score per group, broadcast to all of its rows.
        group_df["inter_score"] = inter_session_similarity(group_df)
        results.append(group_df)

    final_df = pd.concat(results)
    avg_inter_score = final_df['inter_score'].mean()
    final_df.to_json(os.path.join(save_dir, 'diversity_score.jsonl'), orient='records', lines=True, force_ascii=False)
    return avg_inter_score


def check_context(df):
    """Check that every row's ``context`` is an iterable of dialogue rounds.

    Rounds lacking an ``npc`` or ``player`` key are reported on stdout but do
    not fail the check.

    Args:
        df: DataFrame with a ``context`` column.

    Returns:
        ``False`` as soon as a context cannot be iterated (e.g. ``None``) or
        one of its rounds has no ``keys()``; ``True`` otherwise.
    """
    for index, row in df.iterrows():
        context = row['context']
        try:
            rounds = list(context)
        except TypeError:
            # Missing context (None / NaN): malformed, report and reject.
            print(f'error index: {index}')
            print(f'error context: {context}')
            return False

        for dialogue_round in rounds:
            try:
                present_keys = dialogue_round.keys()
            except (AttributeError, TypeError):
                # The round is not a mapping.
                return False
            for role in ('npc', 'player'):
                if role not in present_keys:
                    print(f'error index: {index}')
                    print(f'error round: {dialogue_round}')
    return True
                
def main(eval_file, prompt_dir, embedding_model,  saves_dir):
    """Run the topic, human-likeness and diversity evaluations on one file.

    Args:
        eval_file: Path of the ``.jsonl`` file holding the dialogues.
        prompt_dir: Directory containing ``topic_eval.txt`` and
            ``human_like_eval.txt``.
        embedding_model: Existing path of the embedding model.
        saves_dir: Directory receiving all result files; created if missing.

    Returns:
        None. Scores are printed. Returns early, after printing an error,
        when ``check_context`` rejects the file.
    """
    assert eval_file.endswith('.jsonl'), f'eval file must be .jsonl: {eval_file}'
    assert os.path.exists(embedding_model), f'embedding model not found: {embedding_model}'
    topic_eval_prompt_file = os.path.join(prompt_dir, 'topic_eval.txt')
    # NOTE(review): eval_human_like's own default is human_like_eval2.txt;
    # confirm which template is the intended one.
    human_like_eval_prompt_file = os.path.join(prompt_dir, 'human_like_eval.txt')

    # The evaluators write straight into saves_dir.
    if saves_dir:
        os.makedirs(saves_dir, exist_ok=True)

    df = pd.read_json(eval_file, lines=True)
    if not check_context(df):
        print(f'error eval file: {eval_file}')
        return
    print(f' eval file: {eval_file}')

    topic_score = eval_topic(df, saves_dir, topic_eval_prompt_file)
    print(f'topic score: {topic_score}')

    human_like_score = eval_human_like(df, saves_dir, human_like_eval_prompt_file)
    print(f'human like score: {human_like_score}')

    # Use the `saves_dir` parameter: the name `save_dir` only exists as a
    # global when this file runs as a script.
    avg_inter_score = eval_diversity(df, saves_dir, embedding_model)
    print(f'avg inter score: {avg_inter_score}')
    
if __name__ == "__main__":
    # Script entry point: evaluate the doubao baseline with fixed paths.
    # Note: these assignments create module-level globals when the file is
    # run as a script.
    embedding_model = '/data/models/Qwen3-Embedding-0.6B'
    prompt_dir = './prompts'
    save_dir = './baselines/rpa/doubao_saves'
    eval_file = './baselines/rpa/doubao_saves/doubao.jsonl'

    main(
        eval_file=eval_file,
        prompt_dir=prompt_dir,
        embedding_model=embedding_model,
        saves_dir=save_dir,
    )
