from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from accelerate import Accelerator
import numpy as np
import torch
import torch.nn as nn
from base_trainer import RewardTrainer
from transformers.utils import PaddingStrategy
from transformers import AutoTokenizer
from transformers.trainer_utils import EvalLoopOutput, denumpify_detensorize
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers.trainer_utils import PredictionOutput
from transformers import Trainer

import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel, LlamaConfig
from typing import Optional, Tuple

from transformers import PreTrainedModel, PretrainedConfig, AutoModelForSequenceClassification


@dataclass
class RewardDataCollatorWithPadding:
    """Collate chosen/rejected pairs into one padded, interleaved batch.

    For every feature, the chosen sequence (``input_ids_chosen`` /
    ``attention_mask_chosen``) is emitted first and the rejected sequence
    (``input_ids_rejected`` / ``attention_mask_rejected``) second. After
    padding, row ``2*i`` is therefore the chosen sample and row ``2*i + 1``
    the rejected sample of feature ``i``. Padding is delegated to
    ``self.tokenizer.pad`` using this collator's padding settings.
    """

    tokenizer: AutoTokenizer
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad ``features`` and return the batch dict consumed by the trainer.

        Returns a dict with the padded ``input_ids`` and ``attention_mask``,
        ``return_loss`` set to ``True``, and ``margin``: a plain list with the
        ``margin`` value of each feature that has one, in feature order.
        """
        # (ids key, mask key) for the chosen side, then the rejected side.
        sides = (
            ("input_ids_chosen", "attention_mask_chosen"),
            ("input_ids_rejected", "attention_mask_rejected"),
        )
        flattened = [
            {"input_ids": feature[ids_key], "attention_mask": feature[mask_key]}
            for feature in features
            for ids_key, mask_key in sides
        ]
        margins = [feature["margin"] for feature in features if "margin" in feature]

        padded = self.tokenizer.pad(
            flattened,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        return {
            "input_ids": padded["input_ids"],
            "attention_mask": padded["attention_mask"],
            "return_loss": True,
            "margin": margins,
        }


class SharedTrunkPoERewardModel(LlamaPreTrainedModel):
    """Product-of-Experts (PoE) reward model with a shared trunk and two heads.

    - ``score``: learns the debiased reward; it is the only head used in eval mode.
    - ``bias_head``: learns the bias; its input is perturbed with Gaussian noise
      while training.

    In training mode the two heads' scalar outputs are summed; in eval mode only
    the ``score`` output is returned.
    """

    def __init__(self, config: LlamaConfig, noise_magnitude: float = 0.001):
        super().__init__(config)

        self.model = LlamaModel(config)

        self.score = nn.Linear(config.hidden_size, 1, bias=False)
        self.bias_head = nn.Linear(config.hidden_size, 1, bias=False)

        # Std of the Gaussian noise added to the bias head's input during training.
        self.noise_magnitude = noise_magnitude

        # Use the config's initializer range; 0.02 is the previous hard-coded value
        # and is only used if the config does not define one.
        initializer_range = getattr(config, "initializer_range", 0.02)
        self.score.weight.data.normal_(mean=0.0, std=initializer_range)
        self.bias_head.weight.data.normal_(mean=0.0, std=initializer_range)

        self.post_init()

    def get_last_token_hidden_state(self, hidden_states, attention_mask):
        """Gather the hidden state of the last attended token of each row.

        Args:
            hidden_states: ``[B, T, H]`` final-layer hidden states.
            attention_mask: ``[B, T]`` mask; non-zero marks real tokens.

        Returns:
            ``[B, H]`` tensor. For each row it holds the hidden state at the
            highest position whose mask is non-zero, or position 0 if the row is
            entirely padding. Unlike ``mask.sum() - 1``, this is correct for both
            right- and left-padded batches.
        """
        seq_len = attention_mask.size(1)
        # 1-based position of every attended token, 0 for padded positions; the
        # row-wise max minus one is then the index of the last real token.
        positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
        last_idx = (attention_mask.ne(0).long() * positions).amax(dim=1) - 1
        # All-padding rows yield -1; clamp to 0. Move to hidden_states' device in
        # case the mask lives elsewhere (e.g. a model split across devices).
        last_idx = last_idx.clamp(min=0).to(hidden_states.device)
        batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
        return hidden_states[batch_idx, last_idx]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Score each sequence from its last attended token.

        Args:
            input_ids: token ids passed to the Llama trunk.
            attention_mask: attention mask passed to the trunk; it is also used
                to locate each row's last real token.
            **kwargs: forwarded unchanged to ``self.model``.

        Returns:
            A 1-tuple containing a ``[B, 1]`` reward tensor: ``score + bias_head``
            (with noisy bias input) in training mode, ``score`` alone otherwise.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        shared_representation = self.get_last_token_hidden_state(
            outputs.last_hidden_state, attention_mask
        )
        reward_main = self.score(shared_representation)

        if not self.training:
            # Only the debiased head is used outside training, so the bias head
            # is not evaluated at all.
            return (reward_main,)

        # Noise is applied only to the bias head's input; the main head sees the
        # clean representation.
        noise = torch.randn_like(shared_representation) * self.noise_magnitude
        reward_bias = self.bias_head(shared_representation + noise)
        return (reward_main + reward_bias,)


class PoETrainer(RewardTrainer):
    """Pairwise reward trainer with per-module learning rates.

    The shared trunk and the ``score`` head are optimized with ``main_lr``.
    The ``bias_head`` is optimized with ``bias_lr``. After each evaluation
    loop, pairwise accuracy and average chosen/rejected rewards and margin
    are added to the metrics.
    """

    def __init__(self, *, main_lr: float = 2e-6, bias_lr: float = 6e-6, **kwargs):
        """
        Args:
            main_lr: learning rate for the trunk and ``score`` parameters.
            bias_lr: learning rate for ``bias_head`` parameters.
            **kwargs: forwarded unchanged to ``RewardTrainer``.
        """
        # Set before the parent constructor, as before, in case it reads them.
        self.main_lr = main_lr
        self.bias_lr = bias_lr
        super(PoETrainer, self).__init__(**kwargs)

    def create_optimizer(self):
        """Build the optimizer with per-module learning rates.

        Each trainable parameter is placed in exactly one group:
          * name contains ``bias_head`` -> ``self.bias_lr``
          * name contains ``score``     -> ``self.main_lr``
          * anything else (shared trunk) -> ``self.main_lr``

        Every group uses ``args.weight_decay``, and empty groups are dropped. If
        an optimizer already exists (e.g. passed via ``optimizers=(opt, sched)``),
        it is returned unchanged.
        """
        if getattr(self, "optimizer", None) is not None:
            return self.optimizer

        trunk_params, score_params, bias_params = [], [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # Single-pass, disjoint assignment: overlapping substring filters
            # could put one tensor in two groups, which torch optimizers reject.
            if "bias_head" in name:
                bias_params.append(param)
            elif "score" in name:
                score_params.append(param)
            else:
                trunk_params.append(param)

        groups = [
            {"params": trunk_params, "lr": self.main_lr},
            {"params": score_params, "lr": self.main_lr},
            {"params": bias_params, "lr": self.bias_lr},
        ]
        optimizer_grouped_parameters = [
            {**group, "weight_decay": self.args.weight_decay}
            for group in groups
            if group["params"]
        ]

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        # The per-group learning rates take precedence over the global one.
        optimizer_kwargs.pop("lr", None)
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Run evaluation; currently defers entirely to the parent implementation.

        Override point for benchmark-specific evaluation (e.g. RewardBench v2).
        """
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str = "Evaluating",
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """Run the parent evaluation loop and add pairwise reward metrics.

        Adds ``{prefix}_accuracy`` (fraction where chosen > rejected),
        ``{prefix}_avg_rewards_chosen``, ``{prefix}_avg_rewards_rejected`` and
        ``{prefix}_avg_margin`` to ``output.metrics``.
        """
        output = super().evaluation_loop(
            dataloader,
            description,
            prediction_loss_only,
            ignore_keys,
            metric_key_prefix,
        )

        predictions = output.predictions
        if predictions is None:
            # No predictions collected (e.g. prediction_loss_only) -> nothing to score.
            return output

        # Column 0 = chosen rewards, column 1 = rejected rewards. Predictions are
        # expected to be gathered across processes already. Metrics are computed
        # in float64: float16 loses precision on means and margins.
        rewards_j = np.asarray(predictions[:, 0], dtype=np.float64)
        rewards_k = np.asarray(predictions[:, 1], dtype=np.float64)

        # Added on every process (not only the main one). Otherwise rank-local
        # lookups such as metric_for_best_model would raise KeyError.
        output.metrics.update(
            {
                f"{metric_key_prefix}_accuracy": float((rewards_j > rewards_k).mean()),
                f"{metric_key_prefix}_avg_rewards_chosen": float(rewards_j.mean()),
                f"{metric_key_prefix}_avg_rewards_rejected": float(rewards_k.mean()),
                f"{metric_key_prefix}_avg_margin": float((rewards_j - rewards_k).mean()),
            }
        )

        return output

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Bradley-Terry pairwise loss: ``-mean(logsigmoid(r_chosen - r_rejected))``.

        Rows are expected to be interleaved (even = chosen, odd = rejected), as
        produced by ``RewardDataCollatorWithPadding``.

        Raises:
            ValueError: if the batch has an odd number of rows.
        """
        rewards = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])[0]

        bsz = rewards.size(0)
        if bsz % 2 != 0:
            raise ValueError(
                f"Expected an even number of rows (chosen/rejected pairs), got {bsz}."
            )
        # Stride slicing avoids building index tensors on the wrong device.
        rewards_j = rewards[0::2]
        rewards_k = rewards[1::2]

        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()

        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss