# Copyright 2025 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
# os.environ["WANDB_DISABLED"] = "true"
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from vllm_grpo_trainer import llama3GRPOVLLMTrainer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from llama.Cardio_llama import Cardio_LLaMA
from data.dataset import get_rl_dataset

@dataclass
class MyArguments:
    """
    Script-specific options that are parsed next to the GRPO trainer config.

    Args:
        batch_size (`int`): Batch size per GPU.
        llama_type (`str`): Type of LLaMA model.
        llama_path (`str`): Path to the LLaMA pretrained checkpoint.
        vit_path (`str`): Path to the ViT pretrained checkpoint.
        max_words (`int`): Maximum number of input words.
        model_ckpt (`str`): Path to the pretrained checkpoint that is loaded before training.
        nli_model_name (`str`): Name of the pre-trained NLI model used by the consistency reward.
    """

    # --- optimisation -----------------------------------------------------
    batch_size: int = field(
        default=64,
        metadata={"help": "Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)"},
    )

    # --- backbone / encoder locations --------------------------------------
    llama_type: str = field(default="llama3", metadata={"help": "Type of LLaMA model"})
    llama_path: str = field(default="/path/to/llama", metadata={"help": "Path to LLaMA pretrained checkpoint"})
    vit_path: str = field(default="google/vit-base-patch16-224", metadata={"help": "Path to ViT pretrained checkpoint"})

    # --- data / checkpoint --------------------------------------------------
    max_words: int = field(default=512, metadata={"help": "Max number of input words"})
    model_ckpt: str = field(default="./ckpts/all/ep8/checkpoint.pth", metadata={"help": "Path to the pretrained checkpoint"})

    # --- reward model -------------------------------------------------------
    nli_model_name: str = field(
        default="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        metadata={"help": "Name of the pre-trained NLI model for consistency reward"},
    )

# Module-level cache for the NLI model so it is not reloaded on every reward call.
_nli_model = None
_nli_tokenizer = None
# Name of the checkpoint currently held in the cache (None while the cache is empty).
_nli_model_name = None

def get_nli_model(model_name):
    """Return the cached NLI model and tokenizer, loading them when necessary.

    The model is (re)loaded when the cache is empty or when ``model_name``
    differs from the checkpoint that is currently cached, so a caller never
    receives a model other than the one it asked for.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode and has been
        moved to CUDA when ``torch.cuda.is_available()`` is true.
    """
    global _nli_model, _nli_tokenizer, _nli_model_name
    if _nli_model is None or _nli_model_name != model_name:
        print(f"Loading NLI model: {model_name}")
        # Load into locals first so a failed load never leaves the cache half-updated.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model.eval()
        # Move the model to the GPU when one is available.
        if torch.cuda.is_available():
            model = model.cuda()
        _nli_model, _nli_tokenizer, _nli_model_name = model, tokenizer, model_name
    return _nli_model, _nli_tokenizer

def _nli_label_indices(model):
    """Return ``(entailment_idx, contradiction_idx)`` for the logits of *model*.

    The indices are looked up by label name in ``model.config.id2label``
    because the label order is not the same for every NLI checkpoint. When the
    config does not name both labels, the previously hard-coded convention
    (entailment at index 2, contradiction at index 0) is used as a fallback.
    """
    id2label = getattr(getattr(model, "config", None), "id2label", None) or {}
    index_by_label = {str(label).lower(): int(idx) for idx, label in id2label.items()}
    entailment_idx = index_by_label.get("entailment")
    contradiction_idx = index_by_label.get("contradiction")
    if entailment_idx is None or contradiction_idx is None:
        return 2, 0
    return entailment_idx, contradiction_idx

def reasoning_consistency_reward(completions, ground_truths, nli_model_name="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"):
    """
    Score how well each completion's reasoning supports its own diagnosis.

    The ``<reasoning>`` text is used as the NLI premise and
    ``"The diagnosis is <diagnosis text>"`` as the hypothesis.

    Args:
        completions: Iterable of generated report strings.
        ground_truths: Unused; accepted so the signature matches the other
            reward functions.
        nli_model_name: NLI checkpoint handed to ``get_nli_model``.

    Returns:
        One float per completion:
        0.0 when either tag is missing or contradiction probability > 0.7,
        1.0 when entailment probability > 0.7,
        0.5 otherwise (neutral, text too short, or inference error).
    """
    rewards = []
    # The NLI model is loaded on first use, so a batch without any scorable
    # completion does not trigger a model load.
    model = tokenizer = device = None
    entailment_idx = contradiction_idx = None

    for completion in completions:
        # Extract the reasoning and diagnosis sections.
        reasoning_match = re.search(r"<reasoning>(.*?)</reasoning>", completion, re.DOTALL)
        diagnosis_match = re.search(r"<diagnosis>(.*?)</diagnosis>", completion, re.DOTALL)

        if not reasoning_match or not diagnosis_match:
            rewards.append(0.0)
            continue

        # Strip whitespace and markdown bold markers.
        reasoning_text = re.sub(r"\*\*", "", reasoning_match.group(1).strip()).strip()
        diagnosis_text = re.sub(r"\*\*", "", diagnosis_match.group(1).strip()).strip()

        # Too little text to judge: neutral score. The diagnosis itself is
        # checked (not the hypothesis) because the hypothesis always carries
        # the fixed "The diagnosis is " prefix and is therefore never short.
        if len(reasoning_text) < 10 or not diagnosis_text:
            rewards.append(0.5)
            continue

        # NLI input: reasoning is the premise, the diagnosis is the hypothesis.
        premise = reasoning_text
        hypothesis = f"The diagnosis is {diagnosis_text}"

        if model is None:
            model, tokenizer = get_nli_model(nli_model_name)
            # Use the device the model actually lives on rather than assuming CUDA.
            device = next(model.parameters()).device
            entailment_idx, contradiction_idx = _nli_label_indices(model)

        inputs = tokenizer(
            premise,
            hypothesis,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        try:
            # Forward pass without gradient tracking.
            with torch.no_grad():
                logits = model(**inputs).logits
                probabilities = torch.softmax(logits, dim=1)[0]
                entailment_prob = probabilities[entailment_idx].item()
                contradiction_prob = probabilities[contradiction_idx].item()
        except Exception as e:
            # Best effort: an inference failure yields a neutral score.
            print(f"NLI模型推理错误: {e}")
            rewards.append(0.5)
            continue

        if entailment_prob > 0.7:  # strong entailment
            score = 1.0
        elif contradiction_prob > 0.7:  # strong contradiction
            score = 0.0
        else:  # neutral or weak relation
            score = 0.5
        rewards.append(score)

    return rewards

def structured_format_reward(completions, ground_truths):
    """
    Reward completions that follow the required three-section layout.

    A completion is compliant when it contains all of:
    1. <reasoning>...</reasoning>   - the reasoning process
    2. <diagnosis>...</diagnosis>   - the diagnostic conclusion
    3. <confidence>...</confidence> - whose first occurrence, after stripping
       whitespace, is exactly "High", "Medium" or "Low"

    Args:
        completions: list of generated diagnostic reports
        ground_truths: list of reference labels (not used by this reward)

    Returns:
        list with 1.0 for every compliant completion and 0.0 otherwise
    """
    section_patterns = (
        r"<reasoning>.*?</reasoning>",
        r"<diagnosis>.*?</diagnosis>",
    )
    allowed_levels = ("High", "Medium", "Low")

    def is_compliant(text):
        # Both free-text sections have to be present.
        for pattern in section_patterns:
            if re.search(pattern, text, re.DOTALL) is None:
                return False
        # The confidence section must exist and hold one of the allowed levels.
        found = re.search(r"<confidence>(.*?)</confidence>", text, re.DOTALL)
        if found is None:
            return False
        return found.group(1).strip() in allowed_levels

    return [1.0 if is_compliant(text) else 0.0 for text in completions]

def NIA_AA_reward(completions, ground_truths):
    """
    Score generated AD diagnostic reports against reference reports.

    Each reward is the sum of three parts, capped at 1.0:
      * 0.4 when the diagnosis label of the completion equals the label of the
        reference report (CN / MCI / Dementia);
      * 0.3 scaled by the fraction of biomarker terms found in the completion;
      * 0.3 scaled by the fraction of clinical-feature terms found in the
        completion.

    Labels are detected by whole-word, case-insensitive keyword matching
    inside the ``<diagnosis>`` tag, so that e.g. "Abnormal" is not mistaken
    for the CN keyword "Normal".

    Args:
        completions: list of generated diagnostic reports
        ground_truths: list of reference reports (full tagged format)

    Returns:
        list of rewards in [0.0, 1.0], one per (completion, ground truth) pair.
        A pair scores 0.0 when the reference has no recognisable diagnosis
        label or the completion has no ``<diagnosis>`` tag.
    """
    # Keywords accepted for each standard diagnosis label. Labels are tried in
    # this order and the first one with a matching keyword wins.
    niaaa_criteria = {
        "CN": ["CN", "Cognitively Normal", "Normal"],
        "MCI": ["MCI", "Mild Cognitive Impairment"],
        "Dementia": ["Dementia", "Alzheimer's Disease"],
    }
    # One compiled pattern per label; the lookarounds require the keyword to
    # stand as a whole word.
    label_patterns = {
        label: re.compile(
            "|".join(rf"(?<!\w){re.escape(keyword)}(?!\w)" for keyword in keywords),
            re.IGNORECASE,
        )
        for label, keywords in niaaa_criteria.items()
    }
    # Substring terms counted anywhere in the completion. Note that "tau" is
    # also contained in "pTau"/"tTau", so one mention can count more than once.
    biomarkers = ["Abeta", "pTau", "tTau", "amyloid", "tau"]
    clinical_features = ['memory', 'executive', 'visuospatial', 'cognitive']

    def extract_diagnosis_text(report):
        """Return the cleaned text inside <diagnosis>...</diagnosis>, or None."""
        match = re.search(r"<diagnosis>(.*?)</diagnosis>", report, re.DOTALL)
        if match is None:
            return None
        # Drop markdown bold markers around the label.
        return re.sub(r"\*\*", "", match.group(1).strip()).strip()

    def label_for(diagnosis_text):
        """Map free diagnosis text to a standard label, or None if none matches."""
        for label, pattern in label_patterns.items():
            if pattern.search(diagnosis_text):
                return label
        return None

    rewards = []
    for completion, ground_truth_report in zip(completions, ground_truths):
        # Reference label; without one the pair cannot be scored.
        truth_text = extract_diagnosis_text(ground_truth_report)
        true_diagnosis = label_for(truth_text) if truth_text is not None else None
        if true_diagnosis is None:
            rewards.append(0.0)
            continue

        # The completion must at least contain a diagnosis section.
        model_diagnosis_text = extract_diagnosis_text(completion)
        if model_diagnosis_text is None:
            rewards.append(0.0)
            continue

        reward = 0.0

        # 1. Diagnosis category match (weight 0.4).
        if label_for(model_diagnosis_text) == true_diagnosis:
            reward += 0.4

        completion_lower = completion.lower()

        # 2. Biomarker coverage (weight 0.3).
        mentioned_biomarkers = sum(1 for marker in biomarkers if marker.lower() in completion_lower)
        reward += 0.3 * (mentioned_biomarkers / len(biomarkers))

        # 3. Clinical feature coverage (weight 0.3).
        covered_features = sum(1 for feature in clinical_features if feature.lower() in completion_lower)
        reward += 0.3 * (covered_features / len(clinical_features))

        rewards.append(min(reward, 1.0))  # never exceed 1.0

    return rewards

# Registry that maps a reward name to its reward function.
# NOTE: the "Jaccard" key is bound to NIA_AA_reward, and main() replaces the
# "consistency" entry with a wrapper that passes the configured NLI model name.
reward_funcs_registry = {
    "Jaccard": NIA_AA_reward,
    "format": structured_format_reward,
    "consistency": reasoning_consistency_reward,
}

def main(training_args):
    """Build the model, datasets and reward functions, then run GRPO training.

    Args:
        training_args: Trainer configuration that also carries the
            ``MyArguments`` fields. This function reads ``nli_model_name``,
            ``llama_type``, ``llama_path``, ``model_ckpt`` and ``output_dir``
            and, for ``llama_type == 'llama3'``, sets ``llama_ckpt_dir`` and
            ``llama_tokenzier_path`` on it before handing it to the model.

    Raises:
        ValueError: If ``llama_type`` is not ``'llama3'`` and the checkpoint /
            tokenizer paths have not been provided on ``training_args``.
    """
    # Bind the configured NLI checkpoint into the consistency reward so it has
    # the same two-argument signature as the other reward functions.
    def consistency_wrapper(completions, ground_truths):
        return reasoning_consistency_reward(completions, ground_truths, training_args.nli_model_name)

    # Replace the registry entry with the configured wrapper.
    reward_funcs_registry["consistency"] = consistency_wrapper

    reward_funcs = [reward_funcs_registry[name] for name in ('Jaccard', 'format', 'consistency')]

    llama_type = training_args.llama_type
    if llama_type == 'llama3':
        training_args.llama_ckpt_dir = os.path.join(training_args.llama_path)
        # The attribute name (including its spelling) is kept as-is because
        # training_args is passed on to the model constructor below.
        training_args.llama_tokenzier_path = os.path.join(training_args.llama_path,'tokenizer.model')
    elif not (hasattr(training_args, "llama_ckpt_dir") and hasattr(training_args, "llama_tokenzier_path")):
        # Fail early with a clear message instead of an AttributeError below.
        raise ValueError(
            f"Unsupported llama_type {llama_type!r}: only 'llama3' is handled, and "
            "llama_ckpt_dir / llama_tokenzier_path were not provided."
        )

    model = Cardio_LLaMA(training_args.llama_ckpt_dir, training_args.llama_tokenzier_path, training_args, stage=3, load_llama=False)
    print(model)
    print("Loading Model Checkpoint")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(training_args.model_ckpt, map_location='cpu', weights_only=False)

    # Drop the "module." wrapper component from the parameter names.
    new_ckpt = {}
    for key, value in checkpoint['model'].items():
        key = key.replace("module.", "")
        new_ckpt[key] = value
    del checkpoint
    # strict=False tolerates key mismatches, so report them instead of hiding them.
    load_result = model.load_state_dict(new_ckpt, strict=False)
    print("Checkpoint load result:", load_result)
    model.to("cuda")
    tokenizer = model.tokenizer

    task = 'CNvsCI'

    train_datalist = [
        '{}-train'.format('ADNI'),
    ]
    dataset = get_rl_dataset(train_datalist, task, tokenizer, max_words=300)

    # NOTE(review): evaluation currently reads the same 'ADNI-train' split as
    # training - confirm whether a held-out split should be used here.
    eval_datalist = [
        '{}-train'.format('ADNI'),
    ]
    eval_dataset = get_rl_dataset(eval_datalist, task, tokenizer, max_words=300)

    trainer_cls = llama3GRPOVLLMTrainer
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer; training uses the dataset built from
    # train_datalist, evaluation the one built from eval_datalist.
    trainer = trainer_cls(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
    )

    # Run training.
    trainer.train()

    # Save the trained model to the configured output directory.
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    # Parse the trainer configuration and the script-specific options together.
    grpo_config, script_args = TrlParser((GRPOConfig, MyArguments)).parse_args_and_config()
    # Copy every script option onto the trainer config so that main() only
    # has to deal with a single object.
    for option_name in vars(script_args):
        setattr(grpo_config, option_name, getattr(script_args, option_name))
    main(grpo_config)