# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import Trainer, DataCollatorForSeq2Seq
from typing import Dict, List, Optional, Sequence, Tuple, Union, Any
import numpy as np
from numpy.typing import NDArray

# Label value marking positions that must not contribute to the loss.
IGNORE_INDEX = -100

# ==========================================
# Constants: prompt suffix and answer words
# ==========================================
# Keep this text identical to the prompt used in the inference environment.
# NOTE(review): the wording below (including "if the taking") is runtime prompt
# text; do not "fix" it here without changing the inference side as well.
EVAL_SUFFIX = """

Question: Evaluate if the taking the proposed action in the current state is logically heading in the correct direction for completing the task. Provide an answer of helpful or unhelpful.
Answer: This action is"""

# The answer words we expect after the suffix. The leading space is part of the
# word because it is concatenated directly after "... This action is".
HELPFUL_WORD = " helpful"
UNHELPFUL_WORD = " unhelpful"

@dataclass
class ComputeAccuracy:
    r"""
    Pairwise ranking accuracy.

    ``eval_preds.predictions[0]`` holds the scores of the positive samples and
    ``eval_preds.predictions[1]`` the scores of the negative samples. A pair counts
    as correct when the positive score is strictly greater than the negative one.
    """

    def numpify(self, inputs: Union["NDArray", "torch.Tensor"]):
        """Return ``inputs`` as a NumPy array; non-tensor inputs are returned unchanged."""
        if isinstance(inputs, torch.Tensor):
            inputs = inputs.detach().cpu()
            # NumPy has no bfloat16 dtype, so upcast before converting.
            if inputs.dtype == torch.bfloat16:
                inputs = inputs.to(torch.float32)
            inputs = inputs.numpy()
        return inputs

    def __call__(self, eval_preds, compute_result: bool = True) -> Optional[Dict[str, float]]:
        """
        Compute the fraction of pairs where the positive score beats the negative score.

        Args:
            eval_preds: object exposing ``predictions``, an indexable pair
                ``(positive_scores, negative_scores)`` of arrays or tensors.
            compute_result: accepted for interface compatibility; not used here.

        Returns:
            ``{"accuracy": value}`` with ``value`` a plain Python float.
        """
        # Convert first: calling np.mean directly on a torch tensor comparison fails.
        chosen = np.asarray(self.numpify(eval_preds.predictions[0]))
        rejected = np.asarray(self.numpify(eval_preds.predictions[1]))
        return {"accuracy": float(np.mean(chosen > rejected))}

@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Collator that pads the main ("helpful") sequences and the contrastive
    ("unhelpful") sequences separately and merges them into one batch.

    The output rows are ordered as all "pos" entries of the incoming features
    followed by all "neg" entries, so each output tensor has twice as many rows
    as there are input features.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        def pick(feature, prefix):
            # Select the three columns stored under the given key prefix.
            return {
                "input_ids": feature[f"{prefix}_input_ids"],
                "attention_mask": feature[f"{prefix}_attention_mask"],
                "labels": feature[f"{prefix}_labels"],
            }

        helpful_rows = []    # sequences ending with the "helpful" word
        unhelpful_rows = []  # sequences ending with the "unhelpful" word
        weights = []

        # Step 1: flatten; every feature contributes one "pos" and one "neg" row.
        for side in ("pos", "neg"):
            for feature in features:
                helpful_rows.append(pick(feature, side))
                unhelpful_rows.append(pick(feature, f"{side}_contras"))
                weights.append(feature["factor"])

        # Step 2: pad each group on its own through the parent collator. The two
        # groups may end up with different sequence lengths, but they share the
        # same number of rows.
        padded_helpful = super().__call__(helpful_rows)
        padded_unhelpful = super().__call__(unhelpful_rows)

        # Step 3: merge both groups into a single dictionary.
        columns = ("input_ids", "attention_mask", "labels")
        batch = {name: padded_helpful[name] for name in columns}
        batch.update({f"contras_{name}": padded_unhelpful[name] for name in columns})
        batch["factor"] = torch.tensor(weights, dtype=torch.float32)
        return batch

class RewardModelWithProb(nn.Module):
    r"""
    Wrap a causal LM so that ``forward`` returns P(helpful) for every sample.

    P(helpful) = exp(a) / (exp(a) + exp(b)), where ``a`` and ``b`` are the average
    per-token log-likelihoods of the labelled tokens in the "helpful" and the
    "unhelpful" sequence respectively. The averages are computed per sample from
    the logits (the scalar, batch-averaged loss of the wrapped model is not used).
    """

    def __init__(self, pretrained_model, **kwargs):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.config = pretrained_model.config

        # Expose the wrapped model's gradient-checkpointing switches when it has them.
        if hasattr(pretrained_model, "gradient_checkpointing_disable"):
            self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable
        if hasattr(pretrained_model, "gradient_checkpointing_enable"):
            self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable

    def _avg_log_likelihood(self, input_ids, attention_mask, labels, **kwargs):
        """Return ``(per-sample mean log-likelihood of labelled tokens, model outputs)``."""
        # Labels are deliberately not passed to the wrapped model: its built-in loss
        # is a batch-level scalar, which is useless for a per-sample score.
        outputs = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True,
            **kwargs,
        )

        # Position t predicts token t + 1, so drop the last logit and the first label.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

        # Use the real logit width instead of config.vocab_size, which may differ.
        per_token_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=IGNORE_INDEX,
            reduction="none",
        ).view(shift_labels.shape)

        # Average over labelled tokens only; the clamp avoids dividing by zero
        # for a row without any labelled token.
        labelled = (shift_labels != IGNORE_INDEX).sum(dim=1)
        avg_log_likelihood = -(per_token_loss.sum(dim=1) / labelled.clamp(min=1))
        return avg_log_likelihood, outputs

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        contras_input_ids=None,
        contras_attention_mask=None,
        contras_labels=None,
        return_past_key_values=False,
        **kwargs,
    ):
        """
        Compute P(helpful) for each row of the batch.

        Args:
            input_ids, attention_mask, labels: the "helpful" sequences; ``labels``
                uses ``IGNORE_INDEX`` on every position that is not scored.
            contras_input_ids, contras_attention_mask, contras_labels: the
                matching "unhelpful" sequences.
            return_past_key_values: when true, also return ``past_key_values`` of
                the "helpful" forward pass.
            **kwargs: forwarded to the wrapped model. ``factor``,
                ``output_hidden_states`` and ``return_dict`` are dropped.

        Returns:
            A 1-D tensor with one probability per row, or
            ``(probabilities, past_key_values)`` when ``return_past_key_values`` is set.

        Raises:
            ValueError: if ``labels`` or ``contras_labels`` is missing.
        """
        # These keys are either irrelevant to the wrapped model ("factor") or set
        # explicitly below; leaving them in kwargs would pass the same keyword
        # twice and raise a TypeError.
        for name in ("factor", "output_hidden_states", "return_dict"):
            kwargs.pop(name, None)

        if labels is None or contras_labels is None:
            raise ValueError("Both `labels` and `contras_labels` are required to score the sequences.")

        avg_log_pos, outputs_pos = self._avg_log_likelihood(input_ids, attention_mask, labels, **kwargs)
        avg_log_neg, _ = self._avg_log_likelihood(
            contras_input_ids, contras_attention_mask, contras_labels, **kwargs
        )

        # Softmax over (helpful, unhelpful) is the numerically stable form of
        # exp(a) / (exp(a) + exp(b)); column 0 is P(helpful).
        probs = F.softmax(torch.stack([avg_log_pos, avg_log_neg], dim=1), dim=1)
        value = probs[:, 0]

        if return_past_key_values:
            return (value, outputs_pos.past_key_values)
        return value

def preprocess_value_dataset(
    examples,
    tokenizer,
    max_length=2048,
):
    """
    Tokenize a batch of (prompt, positive action, negative action) examples.

    For each example four sequences are built: the positive and the negative action,
    each followed by ``EVAL_SUFFIX`` and either ``HELPFUL_WORD`` or ``UNHELPFUL_WORD``
    (the latter is stored under the ``*_contras_*`` keys). Only the tokens after the
    suffix carry labels; all earlier positions are set to ``IGNORE_INDEX``.

    Args:
        examples: column-oriented batch with the keys ``prompt``, ``pos``, ``neg``,
            ``pos_count`` and ``neg_count``.
        tokenizer: object providing ``encode(text, add_special_tokens=False)``.
        max_length: maximum number of tokens kept per sequence.

    Returns:
        A dict of lists with ``{pos,neg}[_contras]_{input_ids,attention_mask,labels}``
        plus ``factor``, the per-example loss weight.
    """
    prefixes = ("pos", "pos_contras", "neg", "neg_contras")
    fields = ("input_ids", "attention_mask", "labels")
    model_inputs = {f"{prefix}_{field}": [] for prefix in prefixes for field in fields}
    model_inputs["factor"] = []

    def process_single_seq(prompt_part, action_part, suffix_word):
        # The prefix is everything before the answer word; its token count marks
        # where the labelled region starts.
        # NOTE(review): assumes tokenizing the prefix alone yields the same leading
        # tokens as tokenizing prefix + answer word — confirm for the tokenizer in use.
        prefix_text = prompt_part + action_part + EVAL_SUFFIX
        full_ids = tokenizer.encode(prefix_text + suffix_word, add_special_tokens=False)
        prefix_len = len(tokenizer.encode(prefix_text, add_special_tokens=False))

        # Truncate from the LEFT. The answer word sits at the very end, so cutting
        # the tail would remove the only labelled tokens and leave the sample
        # without any training signal.
        overflow = len(full_ids) - max_length
        if overflow > 0:
            full_ids = full_ids[overflow:]
            prefix_len = max(prefix_len - overflow, 0)
        prefix_len = min(prefix_len, len(full_ids))

        # The label shift for next-token prediction happens later in the model, so
        # labels are aligned with input_ids here and only the prefix is masked.
        labels = [IGNORE_INDEX] * prefix_len + list(full_ids[prefix_len:])
        full_mask = [1] * len(full_ids)
        return full_ids, full_mask, labels

    rows = zip(
        examples["prompt"],
        examples["pos"],
        examples["neg"],
        examples["pos_count"],
        examples["neg_count"],
    )
    for prompt, pos_act, neg_act, pos_count, neg_count in rows:
        sequences = {
            "pos": process_single_seq(prompt, pos_act, HELPFUL_WORD),
            "pos_contras": process_single_seq(prompt, pos_act, UNHELPFUL_WORD),
            "neg": process_single_seq(prompt, neg_act, HELPFUL_WORD),
            "neg_contras": process_single_seq(prompt, neg_act, UNHELPFUL_WORD),
        }
        for prefix, values in sequences.items():
            for field, value in zip(fields, values):
                model_inputs[f"{prefix}_{field}"].append(value)

        # Weight each pair by 1 / (pos_count * neg_count); fall back to 1.0 when
        # either count is zero to avoid dividing by zero.
        if neg_count == 0 or pos_count == 0:
            factor = 1.0
        else:
            factor = 1 / (neg_count * pos_count)
        model_inputs["factor"].append(factor)

    return model_inputs
    
class RMTrainer(Trainer):
    r"""
    Trainer for the pairwise value model.

    Each batch holds the positive-action rows in its first half and the
    negative-action rows in its second half. The model returns P(helpful) for every
    row, and the loss pushes the positive rows above the negative ones.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.can_return_loss = True

    def compute_loss(
        self,
        model,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Compute the weighted pairwise loss ``sum(-logsigmoid(p_pos - p_neg) * factor)``.

        Args:
            model: callable returning one P(helpful) value per input row.
            inputs: collated batch; ``factor`` (optional) holds per-row weights and is
                removed before the remaining entries are passed to the model.
            return_outputs: when true, also return the loss and both probability halves.
            num_items_in_batch: accepted because newer ``Trainer`` versions pass it;
                not used by this loss.

        Returns:
            The scalar loss, or ``(loss, (loss, chosen_probs, rejected_probs))``.

        Raises:
            ValueError: if the model returns an odd number of rows.
        """
        factor = inputs.pop("factor", None)

        # Only `use_cache` is forwarded. The model in this file sets
        # `output_hidden_states` / `return_dict` itself, and passing them here as
        # well makes the inner call receive the same keyword twice.
        values = model(**inputs, use_cache=False)

        if values.shape[0] % 2 != 0:
            raise ValueError(
                f"Expected an even number of rows (positive half + negative half), got {values.shape[0]}."
            )
        batch_size = values.shape[0] // 2
        chosen_probs, rejected_probs = values[:batch_size], values[batch_size:]

        # The collator repeats the weights for both halves; the first half suffices.
        if factor is None:
            factor = torch.ones_like(chosen_probs)
        else:
            factor = factor[:batch_size].to(chosen_probs.device)

        # The inputs are probabilities in [0, 1], not unbounded scores.
        loss = -F.logsigmoid(chosen_probs - rejected_probs)
        final_loss = (loss * factor).sum()

        if return_outputs:
            return final_loss, (final_loss, chosen_probs, rejected_probs)
        return final_loss