#!/usr/bin/env python3
"""
Reward Model Scoring Script
为每个reward模型创建评分函数，从generated_answers.json中提取QA对并进行评分

特殊处理说明：
- QRM-Gemma-2-27B模型使用官方推荐的加载和评分方式
- Skywork-Reward系列模型（名称包含"Skywork-Reward"，如Skywork-Reward-Gemma-2-27B）使用官方推荐的加载和评分方式
- ArmoRM系列模型（名称包含"ArmoRM"）使用官方推荐的加载和评分方式
- 其他模型使用通用的评分逻辑
"""

import os
import json
import torch
import logging
from typing import Dict, List, Any, Tuple
from datetime import datetime
from pathlib import Path
import traceback

# Directory that contains this script; the log file (and, by default, the
# results directory) are placed here.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# SCRIPT_DIR is the directory of this very file, so it already exists; the
# call is a harmless guard.
os.makedirs(SCRIPT_DIR, exist_ok=True)

# INFO-and-above records go both to ``scoring.log`` next to the script and to
# the console (StreamHandler defaults to stderr).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler(os.path.join(SCRIPT_DIR, 'scoring.log')),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

class RewardModelScorer:
    """Reward模型评分器"""
    
    def __init__(
        self,
        model_base_path: str = "/mnt/data/model",
        qa_file_path: str = "/root/gMad/2_qa_generator/generated_answers_1.json",
        output_dir=None,
    ):
        """Configure input/output locations and discover the available reward models.

        All arguments are optional; the defaults reproduce the original
        hard-coded locations, so ``RewardModelScorer()`` behaves as before.

        Args:
            model_base_path: directory whose immediate subdirectories are the
                reward models to score with.
            qa_file_path: JSON file holding the generated QA records.
            output_dir: directory for the per-model result files (``str`` or
                ``None``). ``None`` means ``results_1`` next to this script.
                The directory is created if it does not exist.
        """
        self.model_base_path = model_base_path
        self.qa_file_path = qa_file_path
        self.output_dir = (
            output_dir if output_dir is not None else os.path.join(SCRIPT_DIR, "results_1")
        )
        os.makedirs(self.output_dir, exist_ok=True)

        # 获取所有可用的reward模型 (every subdirectory of model_base_path)
        self.available_models = self._get_available_models()
        logger.info(f"发现 {len(self.available_models)} 个reward模型: {self.available_models}")
        
    def _get_available_models(self) -> List[str]:
        """Return the sorted names of the subdirectories of ``self.model_base_path``.

        Every subdirectory is treated as one reward model. Plain files are
        ignored, and a base path that does not exist yields an empty list.
        """
        base = self.model_base_path
        if not os.path.exists(base):
            return []
        return sorted(
            entry
            for entry in os.listdir(base)
            if os.path.isdir(os.path.join(base, entry))
        )
    
    def load_qa_data(self) -> List[Dict[str, Any]]:
        """Parse the QA JSON file at ``self.qa_file_path`` (UTF-8).

        Returns:
            The decoded JSON content. If opening, decoding or taking the
            length of the result raises, the error is logged and an empty
            list is returned instead of propagating.
        """
        logger.info(f"正在加载QA数据: {self.qa_file_path}")
        try:
            with open(self.qa_file_path, 'r', encoding='utf-8') as handle:
                records = json.load(handle)
            logger.info(f"成功加载 {len(records)} 个QA对")
        except Exception as exc:
            logger.error(f"加载QA数据失败: {exc}")
            records = []
        return records
    
    def extract_qa_pairs(self, qa_data: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
        """Flatten QA records into a list of ``(question, answer)`` tuples.

        Two record layouts are recognised:

        * ``"answers"`` is a dict (answers from several models): one pair per
          value that is a non-empty string. The dict keys (model names) are
          not kept in the output.
        * otherwise, an ``"answer"`` field holding a str/int/float yields one
          pair, with the answer converted to ``str``.

        A missing ``"question"`` becomes ``""``. Entries that are not dicts
        (e.g. ``null`` in the JSON list) are skipped with a warning instead
        of aborting the whole extraction.

        Args:
            qa_data: records as returned by ``load_qa_data``.

        Returns:
            The extracted pairs, in input order.
        """
        qa_pairs = []
        skipped = 0

        for item in qa_data:
            if not isinstance(item, dict):
                # item.get() below would raise AttributeError for these.
                skipped += 1
                continue

            question = item.get('question', '')
            answers = item.get('answers')

            if isinstance(answers, dict):
                # One pair per model answer; only the answer text is needed.
                qa_pairs.extend(
                    (question, answer)
                    for answer in answers.values()
                    if answer and isinstance(answer, str)
                )
            elif 'answer' in item:
                # A single reference answer.
                answer = item['answer']
                if isinstance(answer, (str, int, float)):
                    qa_pairs.append((question, str(answer)))

        if skipped:
            logger.warning(f"跳过了 {skipped} 个非字典格式的QA条目")
        logger.info(f"提取了 {len(qa_pairs)} 个QA对")
        return qa_pairs
    
    def _fix_tokenizer(self, tokenizer):
        """Ensure the tokenizer's special tokens are plain strings so chat templates render.

        ``eos_token``, ``bos_token``, ``unk_token`` and finally ``pad_token``
        are checked:

        * an object with a ``content`` attribute (e.g. ``AddedToken``) whose
          content is a string is accepted and left unchanged;
        * ``None`` is replaced by a default: ``"</s>"`` for eos, ``"<s>"`` for
          bos and ``"<unk>"`` for unk. A missing ``pad_token`` reuses the
          tokenizer's ``eos_token`` and only falls back to ``"</s>"`` when no
          non-empty string eos token is available;
        * any other value is replaced by ``str(value)``, or by the default if
          that conversion raises.

        Args:
            tokenizer: tokenizer to repair; modified in place via ``setattr``.

        Returns:
            The same tokenizer object.
        """
        # Default replacement text for each special token.
        special_tokens_map = {
            "pad_token": "</s>",
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<unk>",
        }

        # pad_token is handled last so a missing pad token can fall back to the
        # (already repaired) eos_token. Handling it first would fill it with
        # "</s>", a string that need not exist in the model's vocabulary.
        for token_name in ("eos_token", "bos_token", "unk_token", "pad_token"):
            default_value = special_tokens_map[token_name]
            token_value = getattr(tokenizer, token_name, None)

            # AddedToken-like objects carry the token text in ``content``.
            if hasattr(token_value, 'content'):
                token_value = token_value.content

            if isinstance(token_value, str):
                continue

            logger.warning(f"Tokenizer的 '{token_name}' 不是字符串 (类型: {type(token_value)})，正在修复。")

            if token_value is None:
                if token_name == "pad_token":
                    eos_value = getattr(tokenizer, "eos_token", None)
                    eos_value = getattr(eos_value, 'content', eos_value)
                    if isinstance(eos_value, str) and eos_value:
                        default_value = eos_value
                setattr(tokenizer, token_name, default_value)
                logger.info(f"已将 '{token_name}' 设置为默认值: '{default_value}'")
            else:
                # Not None and not a string: try a plain str() conversion.
                try:
                    str_value = str(token_value)
                    setattr(tokenizer, token_name, str_value)
                    logger.info(f"已将 '{token_name}' 强制转换为字符串: '{str_value}'")
                except Exception as e:
                    logger.error(f"无法将 '{token_name}' 转换为字符串: {e}，将使用默认值。")
                    setattr(tokenizer, token_name, default_value)

        return tokenizer
    
    def load_model_and_tokenizer(self, model_name: str):
        """Load the reward model stored at ``<model_base_path>/<model_name>`` and its tokenizer.

        The loading recipe is picked from substrings of ``model_name``:

        * ``"QRM-Gemma-2-27B"``: bfloat16 weights, eager attention,
          ``trust_remote_code``;
        * ``"ArmoRM"``: bfloat16 weights, ``trust_remote_code``;
        * ``"Skywork-Reward"``: bfloat16 weights, ``num_labels=1``,
          flash-attention-2, retried with eager attention if that load raises;
        * anything else: tokenizer repaired by ``_fix_tokenizer``. The dtype is
          float32 without CUDA, bfloat16 for Gemma models on bf16-capable GPUs
          and float16 otherwise; ``device_map="auto"`` is used when CUDA is
          available.

        Returns:
            ``(model, tokenizer)`` on success, or ``(None, None)`` if anything
            raises, including a failed ``transformers`` import. The error and
            its traceback are logged.
        """
        model_path = os.path.join(self.model_base_path, model_name)
        logger.info(f"正在加载模型: {model_name}")

        try:
            # Imported here so an ImportError is handled like any other load failure.
            from transformers import AutoTokenizer, AutoModelForSequenceClassification

            use_cuda = torch.cuda.is_available()

            if "QRM-Gemma-2-27B" in model_name:
                logger.info("检测到QRM-Gemma-2-27B模型，使用官方推荐的加载方式")
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                # Official recipe: bfloat16 weights, eager attention instead of flash attention.
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    attn_implementation="eager",
                    device_map="cuda" if use_cuda else "cpu",
                    trust_remote_code=True,
                )
                first_param = next(model.parameters())
                logger.info(f"QRM模型加载成功，设备: {first_param.device}")
                logger.info(f"QRM模型数据类型: {first_param.dtype}")
                return model, tokenizer

            if "ArmoRM" in model_name:
                logger.info(f"检测到ArmoRM模型: {model_name}，使用官方推荐的加载方式")
                # The model is loaded before the tokenizer in this recipe.
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_path,
                    device_map="cuda" if use_cuda else "cpu",
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                )
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                first_param = next(model.parameters())
                logger.info(f"ArmoRM模型加载成功，设备: {first_param.device}")
                logger.info(f"ArmoRM模型数据类型: {first_param.dtype}")
                return model, tokenizer

            if "Skywork-Reward" in model_name:
                logger.info(f"检测到Skywork-Reward模型: {model_name}，使用官方推荐的加载方式")
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                # Arguments shared by the flash-attention attempt and the eager retry.
                shared_kwargs = {
                    "torch_dtype": torch.bfloat16,
                    "device_map": "cuda:0" if use_cuda else "cpu",
                    "num_labels": 1,
                    "trust_remote_code": True,
                }
                try:
                    model = AutoModelForSequenceClassification.from_pretrained(
                        model_path, attn_implementation="flash_attention_2", **shared_kwargs
                    )
                    logger.info("Skywork模型加载成功 (使用flash_attention_2)")
                except Exception as e:
                    # flash-attn may be unavailable or unsupported; retry with eager attention.
                    logger.info(f"flash_attention_2加载失败: {e}，尝试使用eager模式...")
                    model = AutoModelForSequenceClassification.from_pretrained(
                        model_path, attn_implementation="eager", **shared_kwargs
                    )
                    logger.info("Skywork模型加载成功 (使用eager)")
                first_param = next(model.parameters())
                logger.info(f"Skywork模型加载成功，设备: {first_param.device}")
                logger.info(f"Skywork模型数据类型: {first_param.dtype}")
                return model, tokenizer

            # Generic path for every other reward model.
            tokenizer = self._fix_tokenizer(
                AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            )

            if not use_cuda:
                model_dtype = torch.float32
            elif (
                "gemma" in model_name.lower()
                and hasattr(torch.cuda, 'is_bf16_supported')
                and torch.cuda.is_bf16_supported()
            ):
                logger.info(f"检测到Gemma模型且硬件支持bfloat16，将使用torch.bfloat16加载模型 {model_name}")
                model_dtype = torch.bfloat16
            else:
                logger.info(f"将使用torch.float16加载模型 {model_name}")
                model_dtype = torch.float16

            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=model_dtype,
                device_map="auto" if use_cuda else None,
            )
            logger.info(f"成功加载模型: {model_name}")
            return model, tokenizer

        except Exception as e:
            logger.error(f"加载模型 {model_name} 失败: {e}")
            logger.error(traceback.format_exc())
            return None, None
    
    def score_qa_pair_armorm(self, question: str, answer: str, model, tokenizer, device: str) -> float:
        """Score a single QA pair with an ArmoRM reward model, following its official usage.

        The pair is rendered through the tokenizer's chat template as a
        user/assistant exchange, moved to ``device`` and passed to the model
        as ``input_ids``.

        Returns:
            ``output.score`` (per the official docs, the preference score
            aggregated from the multi-objective rewards) when present.
            Otherwise the single logit, or the largest softmax probability
            when there are several logits. 0.0 when the output has neither
            attribute or anything raises; the error is logged.
        """
        try:
            conversation = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            token_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)

            with torch.no_grad():
                result = model(token_ids)

                if hasattr(result, 'score'):
                    return float(result.score.cpu().float().item())

                if not hasattr(result, 'logits'):
                    # Record what the output does expose, to help debugging.
                    public_attrs = [name for name in dir(result) if not name.startswith('_')]
                    logger.warning(f"ArmoRM模型输出中没有找到'score'或'logits'属性，可用属性: {public_attrs}")
                    return 0.0

                # Fallback: regression head -> raw logit; multi-class -> top probability.
                logits = result.logits
                if logits.size(-1) <= 1:
                    value = logits.item()
                else:
                    value = torch.softmax(logits, dim=-1).max().item()
                return float(value)

        except Exception as err:
            logger.error(f"ArmoRM模型评分过程出错: {err}")
            logger.error(traceback.format_exc())
            return 0.0

    def score_qa_pair_skywork(self, question: str, answer: str, model, tokenizer, device: str) -> float:
        """Score a single QA pair with a Skywork-Reward model, following its official usage.

        The pair is rendered to text with the chat template. A leading BOS
        string is stripped, because the tokenizer call afterwards may add its
        own BOS. The text is then tokenized, moved to ``device`` and fed to
        the model.

        Returns:
            ``logits[0][0]`` of the model output; 0.0 if the output has no
            ``logits`` or anything raises (the error is logged).
        """
        try:
            rendered = tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ],
                tokenize=False,
            )

            # Drop a possible duplicate BOS, as in the official example.
            bos = tokenizer.bos_token
            if bos is not None and rendered.startswith(bos):
                rendered = rendered[len(bos):]

            encoded = tokenizer(rendered, return_tensors="pt").to(device)

            with torch.no_grad():
                result = model(**encoded)

                if not hasattr(result, 'logits'):
                    # Record what the output does expose, to help debugging.
                    public_attrs = [name for name in dir(result) if not name.startswith('_')]
                    logger.warning(f"Skywork模型输出中没有找到'logits'属性，可用属性: {public_attrs}")
                    return 0.0

                return float(result.logits[0][0].item())

        except Exception as err:
            logger.error(f"Skywork模型评分过程出错: {err}")
            logger.error(traceback.format_exc())
            return 0.0

    def score_qa_pair_qrm_gemma(self, question: str, answer: str, model, tokenizer, device: str) -> float:
        """Score a single QA pair with QRM-Gemma-2-27B, following its official usage.

        The pair is rendered through the tokenizer's chat template as a
        user/assistant exchange and passed to the model as ``input_ids``.

        Args:
            question: user prompt.
            answer: assistant response to score.
            model: the loaded QRM model.
            tokenizer: its tokenizer (its chat template is used).
            device: device the input ids are moved to, e.g. ``"cuda:0"``.

        Returns:
            ``output.score`` (per the official docs, the expectation of the
            reward distribution) when present. Otherwise the single logit, or
            the largest softmax probability for multi-logit outputs. 0.0 if
            the output has neither attribute or scoring raises; the error is
            logged.
        """
        try:
            messages = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer}
            ]
            input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)

            with torch.no_grad():
                output = model(input_ids)

                if hasattr(output, 'score'):
                    score_value = float(output.score.cpu().float().item())
                    # Log the shape of the reward quantiles once, for debugging.
                    # Only ``.shape`` is needed, so the tensor is not copied to
                    # the CPU on every call. A missing/None field no longer
                    # raises and throws away the score computed above.
                    quantiles = getattr(output, 'reward_quantiles', None)
                    if quantiles is not None and not hasattr(self, '_qrm_quantiles_logged'):
                        logger.info(f"QRM模型奖励分位数形状: {quantiles.shape}")
                        self._qrm_quantiles_logged = True
                    return score_value

                if hasattr(output, 'logits'):
                    # Fallback: regression head -> raw logit; multi-class -> top probability.
                    logits = output.logits
                    if logits.size(-1) > 1:
                        score = torch.softmax(logits, dim=-1).max().item()
                    else:
                        score = logits.item()
                    return float(score)

                # Record what the output does expose, to help debugging.
                available_attrs = [attr for attr in dir(output) if not attr.startswith('_')]
                logger.warning(f"QRM模型输出中没有找到'score'或'logits'属性，可用属性: {available_attrs}")
                return 0.0

        except Exception as e:
            logger.error(f"QRM模型评分过程出错: {e}")
            logger.error(traceback.format_exc())
            return 0.0

    def score_qa_pair(self, question: str, answer: str, model, tokenizer, max_length: int = 512) -> float:
        """Score one QA pair with a generic sequence-classification reward model.

        If the tokenizer has a chat template, the pair is rendered as a
        user/assistant conversation. Otherwise a plain
        ``"Question: ...\\nAnswer: ..."`` text is used.

        Args:
            question: user prompt.
            answer: assistant response to score.
            model: model whose forward pass accepts the tokenizer outputs as
                keyword arguments.
            tokenizer: tokenizer paired with ``model``.
            max_length: maximum number of tokens kept; longer inputs are
                truncated. Defaults to the previous hard-coded 512.

        Returns:
            The raw logit for a single-output head, or the largest softmax
            probability for a multi-output head. 0.0 when the output has no
            ``logits`` or scoring raises; the error and traceback are logged.
        """
        try:
            if tokenizer.chat_template:
                messages = [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer}
                ]
                input_text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False
                )
                # Per the HF chat-template docs, rendered templates already
                # contain the special tokens they need (e.g. BOS). Adding them
                # again at tokenization would duplicate them; the Skywork
                # scorer strips the BOS for the same reason.
                add_special_tokens = True if False else False
                add_special_tokens = False
            else:
                # Fallback plain-text format; here the tokenizer's own special tokens are wanted.
                input_text = f"Question: {question}\nAnswer: {answer}"
                add_special_tokens = True

            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=True,
                add_special_tokens=add_special_tokens,
            )

            # Move the inputs to the model's device when it exposes one.
            if hasattr(model, 'device'):
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)

                if hasattr(outputs, 'logits'):
                    logits = outputs.logits
                    if logits.size(-1) > 1:
                        # Classification head: take the highest class probability.
                        score = torch.softmax(logits, dim=-1).max().item()
                    else:
                        # Regression head: use the raw output value.
                        score = logits.item()
                else:
                    score = 0.0

            return float(score)

        except Exception as e:
            logger.error(f"评分过程出错: {e}")
            logger.error(traceback.format_exc())
            return 0.0
    
    def score_with_model(self, model_name: str, qa_pairs: List[Tuple[str, str]]) -> Dict[str, Any]:
        """Score every QA pair with one reward model and summarise the results.

        The model is loaded with ``load_model_and_tokenizer``. The scoring
        routine is chosen from substrings of ``model_name`` (QRM / Skywork /
        ArmoRM specific scorers, else the generic one). Each pair is scored,
        and the model is always released afterwards.

        Args:
            model_name: directory name of the model under ``model_base_path``.
            qa_pairs: ``(question, answer)`` tuples to score.

        Returns:
            A dict with ``model_name`` and ``status``:

            * ``"failed"``: the model could not be loaded; ``scores`` is empty.
            * ``"completed"``: ``scores`` (question/answer/score/index per
              pair), ``statistics`` (mean/min/max/total_pairs) and
              ``timestamp``.
            * ``"error"``: an unexpected exception while scoring. ``error``
              holds its message and ``scores`` holds the pairs scored before
              the failure, so partial work can still be saved.
        """
        import gc  # used to collect released model objects before emptying the CUDA cache

        logger.info(f"开始使用模型 {model_name} 进行评分...")

        model, tokenizer = self.load_model_and_tokenizer(model_name)
        if model is None or tokenizer is None:
            logger.error(f"无法加载模型 {model_name}，跳过评分")
            return {"model_name": model_name, "status": "failed", "scores": []}

        scores = []
        total_pairs = len(qa_pairs)

        # Choose the scoring routine from the same name substrings used when loading.
        is_qrm_model = "QRM-Gemma-2-27B" in model_name
        is_skywork_model = "Skywork-Reward" in model_name
        is_armorm_model = "ArmoRM" in model_name
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

        if is_qrm_model:
            logger.info(f"使用QRM专用评分函数处理模型 {model_name}")
        elif is_skywork_model:
            logger.info(f"使用Skywork专用评分函数处理模型 {model_name}")
        elif is_armorm_model:
            logger.info(f"使用ArmoRM专用评分函数处理模型 {model_name}")
        else:
            logger.info(f"使用通用评分函数处理模型 {model_name}")

        try:
            for i, (question, answer) in enumerate(qa_pairs):
                if i % 100 == 0:  # progress every 100 pairs
                    logger.info(f"评分进度: {i}/{total_pairs} ({i/total_pairs*100:.1f}%)")

                if is_qrm_model:
                    score = self.score_qa_pair_qrm_gemma(question, answer, model, tokenizer, device)
                elif is_skywork_model:
                    score = self.score_qa_pair_skywork(question, answer, model, tokenizer, device)
                elif is_armorm_model:
                    score = self.score_qa_pair_armorm(question, answer, model, tokenizer, device)
                else:
                    score = self.score_qa_pair(question, answer, model, tokenizer)

                scores.append({
                    "question": question,
                    "answer": answer,
                    "score": score,
                    "index": i
                })

            logger.info(f"模型 {model_name} 评分完成，共评分 {len(scores)} 个QA对")

            score_values = [item["score"] for item in scores]
            stats = {
                "mean_score": sum(score_values) / len(score_values) if score_values else 0,
                "min_score": min(score_values) if score_values else 0,
                "max_score": max(score_values) if score_values else 0,
                "total_pairs": len(scores)
            }

            return {
                "model_name": model_name,
                "status": "completed",
                "scores": scores,
                "statistics": stats,
                "timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"模型 {model_name} 评分过程中出错: {e}")
            logger.error(traceback.format_exc())
            # Keep the pairs scored before the failure so they are not lost.
            return {"model_name": model_name, "status": "error", "error": str(e), "scores": scores}

        finally:
            # Drop this function's references to the model/tokenizer, collect
            # any reference cycles, then release cached CUDA memory before the
            # next model is loaded.
            del model, tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    def save_results(self, results: Dict[str, Any]):
        """Write one model's results to ``<output_dir>/<model_name>.json``.

        The whole payload is serialised before anything is written. It is then
        written to a temporary file that replaces the target via
        ``os.replace``. ``json.dump`` writes incrementally, so a serialisation
        error or an interrupted write could otherwise leave a truncated file
        and overwrite a previous good result.

        Args:
            results: result dict; must contain ``"model_name"``.

        Failures while serialising or writing are logged, not raised (a
        missing ``"model_name"`` key still raises ``KeyError``).
        """
        model_name = results["model_name"]
        output_file = os.path.join(self.output_dir, f"{model_name}.json")
        tmp_file = output_file + ".tmp"

        try:
            payload = json.dumps(results, ensure_ascii=False, indent=2)
            with open(tmp_file, 'w', encoding='utf-8') as f:
                f.write(payload)
            os.replace(tmp_file, output_file)
            logger.info(f"评分结果已保存到: {output_file}")
        except Exception as e:
            logger.error(f"保存结果失败: {e}")
            # Best-effort cleanup of a partially written temporary file.
            if os.path.exists(tmp_file):
                try:
                    os.remove(tmp_file)
                except OSError:
                    pass  # the original error has already been logged
