import itertools
import torch
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outpus_and_unpad

from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm, implicit_drm_loss
from .utils import ulysses_pad_and_slice_inputs, chunked_logprobs_from_logits
__all__ = ["DataParallelDistributionLevelRewardModel"]


# MODIFIED FROM `DataParallelPRIMERewardModel`
class DataParallelDistributionLevelRewardModel:
    def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):
        self.config = config
        self.reward_module = reward_module
        self.ref_module = ref_module
        self.reward_optimizer = reward_optimizer
        self.use_remove_padding = self.config.model.get("use_remove_padding", False)
        print(f"Reward model use_remove_padding={self.use_remove_padding}")
        self.use_fused_kernels = self.config.model.get("use_fused_kernels", False)
        print(f"Reward model use_fused_kernels={self.use_fused_kernels}")

        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)

    def _forward_micro_batch(self, micro_batch):
        """Score one micro-batch with the reward model relative to a reference.

        The reward model is always run here. The reference log-probs come either
        from ``self.ref_module`` (run under ``no_grad``/bf16 autocast) or, when
        ``self.ref_module`` is ``None``, from tensors already stored in
        ``micro_batch`` as selected by ``config.ref_model_train`` /
        ``config.ref_model_infer``.

        Args:
            micro_batch: Mapping that provides ``"input_ids"``, ``"attention_mask"``,
                ``"position_ids"``, ``"responses"`` and ``"old_log_prob_topk_indices"``.
                When ``self.ref_module`` is ``None`` it must additionally provide the
                ``"ref_log_prob"`` / ``"old_log_prob"`` and
                ``"ref_log_prob_topk_values"`` / ``"old_log_prob_topk_values"``
                entries named by the config.
                ``input_ids`` and ``responses`` are 2-D (they are unpacked into two
                dimensions below); the top-k indices are presumably
                ``(batch, seq, k)`` -- TODO confirm against the producer.

        Returns:
            A tuple ``(token_level_score, q, distribution_level_score)``:

            * ``token_level_score``: ``beta``-scaled (reward - reference) log-prob of the
              realised tokens over the response window, laid out according to
              ``config.prime_granularity``; built under ``no_grad``.
            * ``q``: unscaled (reward - reference) log-prob difference over the response
              window, zeroed past each sample's response length; NOT built under
              ``no_grad`` here, so it is what the training loss is computed from.
            * ``distribution_level_score``: ``beta``-scaled (reward - reference) log-prob
              difference at the top-k indices, copied for positions before
              ``max_positions - 1`` and zero elsewhere; built under ``no_grad``.

        Raises:
            NotImplementedError: If ``config.ref_model_train``, ``config.ref_model_infer``
                (both only consulted when ``self.ref_module`` is ``None``) or
                ``config.prime_granularity`` has an unsupported value.
        """
        input_ids = micro_batch["input_ids"]
        attention_mask = micro_batch["attention_mask"]
        position_ids = micro_batch["position_ids"]
        batch_size, max_seq_length = input_ids.shape
        batch_size, max_response_length = micro_batch['responses'].shape
        # The last `max_response_length` columns of the sequence are treated as the response window.
        max_prompt_length = max_seq_length - max_response_length
        # Per-sample count of attended tokens inside the response window; used below as a cut-off index.
        # NOTE(review): using a count as an index presumes the attended response tokens are
        # contiguous from the start of the window -- confirm against the data layout.
        max_positions = attention_mask[:, max_prompt_length:].sum(-1)
        old_log_prob_topk_indices = micro_batch['old_log_prob_topk_indices']
        
        if self.use_remove_padding:
            # Drop padded positions and pack every attended token into one long sequence.
            input_ids_rmpad, indices, *_ = unpad_input(
                input_ids.unsqueeze(-1), attention_mask
            )  # input_ids_rmpad (total_nnz, ...)
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # unpad the position_ids to align the rotary
            position_ids_rmpad = index_first_axis(
                rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
            ).transpose(0, 1)

            # Next-token labels: shifting left by one makes position t hold token t+1,
            # which is what the logits at position t predict.
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

            # pad and slice the inputs if sp > 1
            if self.ulysses_sequence_parallel_size > 1:
                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
                )
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
                )

            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
            output = self.reward_module(
                input_ids=input_ids_rmpad,
                attention_mask=None,
                position_ids=position_ids_rmpad,
                use_cache=False,
                return_dict=True,
            )
            rm_output_logits = output.logits.squeeze(0)
            # Log-prob the reward model assigns to each realised next token.
            rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled)

            if self.ulysses_sequence_parallel_size > 1:
                rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
            # Scatter back to (batch, seq) and keep the columns whose logits predict the
            # response tokens, i.e. the response window shifted one step to the left.
            rm_log_labels = pad_input(
                hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=max_seq_length
            ).squeeze(-1)[:, -max_response_length - 1 : -1]

            # SUPPORT Distribution-Level Adv
            # Same unpad / roll / (optional) sequence-parallel slicing as for input_ids,
            # applied to the top-k token indices so they line up with the packed logits.
            old_log_prob_topk_indices_rmpad, *_ = unpad_input(old_log_prob_topk_indices.unsqueeze(-1), attention_mask)
            old_log_prob_topk_indices_rmpad = old_log_prob_topk_indices_rmpad.permute(2, 0, 1)
            old_log_prob_topk_indices_rmpad_rolled = torch.roll(old_log_prob_topk_indices_rmpad, shifts=-1, dims=1) 
            if self.ulysses_sequence_parallel_size > 1:
                old_log_prob_topk_indices_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(old_log_prob_topk_indices_rmpad_rolled, None, self.ulysses_sequence_parallel_size)
                rm_log_prob_topk_values = chunked_logprobs_from_logits(rm_output_logits, old_log_prob_topk_indices_rmpad_rolled[0])
                rm_log_prob_topk_values = gather_outpus_and_unpad(rm_log_prob_topk_values, gather_dim=0, unpad_dim=0, padding_size=pad_size)
            else:
                old_log_prob_topk_indices_rmpad_rolled = old_log_prob_topk_indices_rmpad_rolled.squeeze(0)
                rm_log_prob_topk_values = chunked_logprobs_from_logits(rm_output_logits, old_log_prob_topk_indices_rmpad_rolled)

            # Ensure a trailing feature dimension exists before scattering back with pad_input.
            if rm_log_prob_topk_values.dim() < 2:
                rm_log_prob_topk_values = rm_log_prob_topk_values.unsqueeze(-1)

            rm_log_prob_topk_values = pad_input(hidden_states=rm_log_prob_topk_values,
                                                indices=indices, 
                                                batch=batch_size,
                                                seqlen=max_seq_length)[:, -max_response_length - 1:-1]

        else:
            # Padded path: run the model on the full (batch, seq) tensors.
            output = self.reward_module(
                input_ids=micro_batch["input_ids"],
                attention_mask=micro_batch["attention_mask"],
                position_ids=micro_batch["position_ids"],
                use_cache=False,
                return_dict=True,
            )
            rm_output_logits = output.logits
            # Logits at position t are scored against token t+1, hence [:-1] vs [1:].
            rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :], dim=-1)
            rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze(-1)

            # SUPPORT Distribution-Level Adv
            rm_log_prob_topk_values = chunked_logprobs_from_logits(
                logits=rm_output_logits, 
                labels=torch.roll(old_log_prob_topk_indices, shifts=-1, dims=1)
                )[:, -max_response_length - 1:-1]
            
            # Keep a trailing top-k dimension even if the helper returned a 2-D tensor.
            if rm_log_prob_topk_values.dim() < 3:
                rm_log_prob_topk_values = rm_log_prob_topk_values.unsqueeze(-1)

        if self.ref_module is not None:
            # do not have to pad again
            with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
                # The packed tensors are reused only when sequence parallelism is active;
                # otherwise the reference model runs on the padded inputs.
                if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
                    ref_output = self.ref_module(
                        input_ids=input_ids_rmpad,
                        attention_mask=None,
                        position_ids=position_ids_rmpad,
                        use_cache=False,
                    )

                    ref_output_logits = ref_output.logits.squeeze(0)
                    ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, labels=input_ids_rmpad_rolled)
                    ref_log_labels = gather_outpus_and_unpad(
                        ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
                    )
                    ref_log_labels = pad_input(
                        hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=max_seq_length
                    ).squeeze(-1)[:, -max_response_length - 1 : -1]
                    
                    # ref logp for topk
                    ref_log_prob_topk_values = chunked_logprobs_from_logits(ref_output_logits, old_log_prob_topk_indices_rmpad_rolled[0])
                    ref_log_prob_topk_values = gather_outpus_and_unpad(ref_log_prob_topk_values, gather_dim=0, unpad_dim=0, padding_size=pad_size)

                    if ref_log_prob_topk_values.dim() < 2:
                        ref_log_prob_topk_values = ref_log_prob_topk_values.unsqueeze(-1)

                    ref_log_prob_topk_values = pad_input(hidden_states=ref_log_prob_topk_values,
                                                        indices=indices,
                                                        batch=batch_size,
                                                        seqlen=max_seq_length)[:, -max_response_length - 1: -1]
                else:
                    ref_output = self.ref_module(
                        input_ids=micro_batch["input_ids"],
                        attention_mask=micro_batch["attention_mask"],
                        position_ids=micro_batch["position_ids"],
                        use_cache=False,
                    )
                    ref_output_logits = ref_output.logits
                    # NOTE(review): labels are passed with an extra trailing dim here
                    # (`.unsqueeze(-1)`), unlike the packed branch above -- confirm that
                    # `verl_F.logprobs_from_logits` returns a (batch, seq-1) tensor for this shape.
                    ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits[:, :-1, :], labels=micro_batch["input_ids"][:, 1:].unsqueeze(-1))
                    ref_log_prob_topk_values = chunked_logprobs_from_logits(logits=ref_output_logits, labels=torch.roll(old_log_prob_topk_indices, shifts=-1, dims=1))[:, -max_response_length - 1:-1]
                    
                    if ref_log_prob_topk_values.dim() < 3:
                        ref_log_prob_topk_values = ref_log_prob_topk_values.unsqueeze(-1)
        
        # REWARD SCORE FOR REWARD MODEL TRAINING 
        # No reference module: take the reference log-probs for the training signal `q`
        # from precomputed tensors in the micro-batch.
        elif self.config.ref_model_train == "ref":
            ref_log_labels = micro_batch["ref_log_prob"]
        elif self.config.ref_model_train == "old":
            ref_log_labels = micro_batch["old_log_prob"]
        else:
            raise NotImplementedError
        
        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)
        q = rm_log_labels[:, -max_response_length:] - ref_log_labels[:, -max_response_length:]  # this is actually diff of q for rm training
        # Zero everything from each sample's response length onward.
        for i in range(micro_batch["input_ids"].shape[0]):
            q[i, max_positions[i] :] = 0


        # REWARD SCORE FOR REINFORCEMENT LEARNING TRAINING 
        # The reference used for the returned scores may differ from the one used for `q`
        # when there is no reference module (`ref_model_infer` vs `ref_model_train`).
        if self.ref_module is not None:
            assert 'ref_log_prob_topk_values' in locals(), \
                "ref_log_prob_topk_values should have been computed above when ref_module is not None"
        elif self.config.ref_model_infer == "ref":
            ref_log_labels = micro_batch["ref_log_prob"]
            ref_log_prob_topk_values = micro_batch["ref_log_prob_topk_values"]
        elif self.config.ref_model_infer == "old":
            ref_log_labels = micro_batch["old_log_prob"]
            ref_log_prob_topk_values = micro_batch["old_log_prob_topk_values"]
        else:
            raise NotImplementedError
        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)

        with torch.no_grad():
            # calculate the difference
            # `beta_train` defaults to 0.05 when absent from the model config.
            beta = self.config.model.get("beta_train", 0.05)
            distribution_level_reward = beta * (rm_log_prob_topk_values[:, -max_response_length:] - ref_log_prob_topk_values[:, -max_response_length:])
            token_level_reward = beta * (rm_log_labels[:, -max_response_length:] - ref_log_labels[:, -max_response_length:])

            # trim unnecessary logprobs 
            for i in range(micro_batch["input_ids"].shape[0]):
                distribution_level_reward[i, max_positions[i] :] = 0
                token_level_reward[i, max_positions[i] :] = 0

            # different granularity of prime
            token_level_score = torch.zeros_like(token_level_reward)
            for i in range(micro_batch["input_ids"].shape[0]):
                # Index of the last attended response token of sample i.
                last_idx = int(max_positions[i].item()) - 1
                if self.config.prime_granularity == "token":
                    # Per-token rewards; the slice stops before `last_idx`, so that position stays 0.
                    token_level_score[i, :last_idx] = token_level_reward[i, : last_idx]
                elif self.config.prime_granularity == "whole":
                    # One scalar per sample: the summed reward, placed on the last response token.
                    token_level_score[i, last_idx] = token_level_reward[i, : last_idx + 1].sum()
                else:
                    raise NotImplementedError
            
            # Same cut-off as the "token" granularity: copy positions before `max_positions - 1`.
            distribution_level_score = torch.zeros_like(distribution_level_reward)
            for i in range(micro_batch["input_ids"].shape[0]):
                distribution_level_score[i, : max_positions[i] - 1] = distribution_level_reward[i, : max_positions[i] - 1]
        return token_level_score, q, distribution_level_score

    def _optimizer_step(self):
        assert self.config.model.optim.grad_clip is not None

        if isinstance(self.reward_module, FSDP):
            grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip
            )
        self.reward_optimizer.step()
        return grad_norm

    # def prime_norm(self, token_level_scores, response_mask):
    #     if self.config.prime_norm == "batch_norm":
    #         if self.config.seq_agg == "sum":
    #             reverse_cumsum = suffix_sum_with_mask(token_level_scores, response_mask)
    #             token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
    #         elif self.config.seq_agg == "mean":
    #             reverse_cumsum = suffix_mean_with_mask(token_level_scores, response_mask)
    #             token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
    #         elif self.config.seq_agg == "geo_mean":
    #             pass
    #             # reverse_cumsum = suffix_mean_with_mask(token_level_scores, response_mask)
    #             # token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
    #     return token_level_scores
    

    # def dl_norm(self, distribution_level_adv_values):
    #     if self.config.prime_norm == "batch_norm":
    #         distribution_level_adv_values = distribution_level_adv_values / (distribution_level_adv_values.abs().max() + 1e-6)
    #     return distribution_level_adv_values

    def compute_rm_score(self, data: DataProto):
        """Compute reward-model scores for a batch without updating the model.

        The batch is split into micro-batches (by token budget when
        ``meta_info["use_dynamic_bsz"]`` is set, otherwise by
        ``meta_info["micro_batch_size"]``), each is scored with
        ``_forward_micro_batch`` under ``no_grad``, and the results are concatenated.
        With dynamic batch sizing the samples are re-ordered while batching, so all
        three outputs are restored to the original sample order before returning.

        Args:
            data: Batch container. Reads ``meta_info["micro_batch_size"]``,
                ``meta_info["use_dynamic_bsz"]`` and, for dynamic batching,
                ``meta_info["max_token_len"]``; reads the tensors named in
                ``select_keys`` from ``data.batch``.

        Returns:
            A tuple ``(rm_scores, q, distribution_level_adv_values, metrics)`` where
            ``metrics`` holds the mean of ``rm_scores`` and of ``q`` over the positions
            selected by ``data.batch["response_mask"]``.
        """
        self.reward_module.eval()
        if self.ref_module is not None:
            self.ref_module.eval()
        micro_batch_size = data.meta_info["micro_batch_size"]
        select_keys = [
            "responses", "input_ids", "attention_mask", "position_ids", "acc",
            "old_log_prob", "old_log_prob_topk_indices", "old_log_prob_topk_values", "response_mask",
            "ref_log_prob", "ref_log_prob_topk_values"]
        batch = data.select(batch_keys=select_keys).batch
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]

        if use_dynamic_bsz:
            # split using dynamic bsz
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        rm_scores_lst = []
        q_lst = []
        distribution_level_adv_values_lst = []
        for micro_batch in micro_batches:
            with torch.no_grad():
                rm_score, q, distribution_level_adv_values = self._forward_micro_batch(micro_batch)
            rm_scores_lst.append(rm_score)
            q_lst.append(q)
            distribution_level_adv_values_lst.append(distribution_level_adv_values)
        rm_scores = torch.concat(rm_scores_lst, dim=0)
        q = torch.concat(q_lst, dim=0)
        distribution_level_adv_values = torch.concat(distribution_level_adv_values_lst, dim=0)
        # rm_scores = self.prime_norm(rm_scores, data.batch['response_mask'])
        # distribution_level_adv_values = self.dl_norm(distribution_level_adv_values)

        if use_dynamic_bsz:
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            # Every per-sample output must go back to the caller's order, not just
            # rm_scores; otherwise q / distribution-level values end up attached to
            # the wrong samples and the masked metrics below mix rows.
            rm_scores = rm_scores[revert_indices]
            q = q[revert_indices]
            distribution_level_adv_values = distribution_level_adv_values[revert_indices]

        response_mask = data.batch['response_mask'].bool()
        return (
            rm_scores,
            q.detach(),
            distribution_level_adv_values,
            {
                "reward_model/reward": rm_scores[response_mask].mean().item(),
                "reward_model/raw_reward": q[response_mask].mean().item(),
            },
        )

    def update_rm(self, data: DataProto):
        """Run one pass of reward-model updates over ``data`` and return fresh scores.

        The batch is split into mini-batches of ``config.mini_batch_size``; each
        mini-batch is split into micro-batches (by token budget when
        ``config.use_dynamic_bsz`` is set, otherwise by
        ``config.micro_batch_size_per_gpu``) whose losses are accumulated before a
        single clipped optimizer step.

        Args:
            data: Batch container. Reads ``meta_info["n"]`` (samples per prompt, used
                for the optional loss weighting) and the tensors named in
                ``select_keys`` (plus ``"Q_bc"`` / ``"acc_bc"`` when present).

        Returns:
            A tuple ``(rm_scores, distribution_level_adv_values, metrics)``. The score
            tensors are concatenated over all mini-batches in the original sample
            order; ``metrics`` collects per-micro-batch losses, per-step gradient norms
            and masked means of the scores.

        Raises:
            NotImplementedError: If ``config.model.loss_type`` is not one of
                ``"ce"``, ``"dpo"`` or ``"drm"``.
        """
        # make sure we are in training mode
        self.reward_module.train()
        metrics = {}
        beta = self.config.model.get("beta_train", 0.05)
        select_keys = ["prompts", "input_ids", "responses", "attention_mask", "position_ids", "acc",
            "old_log_prob", "old_log_prob_topk_indices", "old_log_prob_topk_values", "response_mask",
            "ref_log_prob", "ref_log_prob_topk_values"]
        n_samples = data.meta_info["n"]

        for key in ["Q_bc", "acc_bc"]:
            if key in data.batch.keys():
                select_keys.append(key)

        batch = data.select(batch_keys=select_keys).batch

        # Per-sample loss weight.
        if self.config.model.use_loss_weight:
            grouped_accs = data.batch["acc"].reshape(-1, n_samples)  # [num_groups, n_samples]
            grouped_accs_mean = grouped_accs.mean(dim=-1, keepdim=True)  # [num_groups, 1]
            # Samples with acc > 0.5 are weighted by (1 - group mean), the others by the group mean.
            weight = torch.where(grouped_accs > 0.5, 1.0 - grouped_accs_mean, grouped_accs_mean)
            # Normalise by the mean weight of the group; groups whose mean weight is 0 are left as-is.
            weight1 = weight.mean(dim=-1, keepdim=True)
            weight1[weight1 == 0] = 1
            weight = (weight / weight1).flatten()
            batch["loss_weight"] = weight
        else:
            batch["loss_weight"] = torch.ones_like(data.batch["acc"])

        rm_scores_lst = []
        q_lst = []
        distribution_level_adv_values_lst = []

        for mini_batch in batch.split(self.config.mini_batch_size):
            # split batch into micro_batches
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                # Keep the index lists: dynamic batching re-orders the samples of the
                # mini-batch and the outputs must be put back in order afterwards.
                micro_batches, micro_indices = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
                micro_indices = None
                self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu

            self.reward_optimizer.zero_grad()

            mini_rm_scores = []
            mini_q = []
            mini_distribution_level = []

            for micro_batch in micro_batches:
                micro_batch = micro_batch.to(get_device_name())
                attention_mask = micro_batch["attention_mask"]
                acc = micro_batch["acc"]
                prompt_length = micro_batch["prompts"].shape[-1]
                response_mask = attention_mask[:, prompt_length:]
                loss_weight = micro_batch["loss_weight"]
                rm_score, q, distribution_level_adv_values = self._forward_micro_batch(micro_batch)
                mini_rm_scores.append(rm_score)
                mini_q.append(q.detach())
                mini_distribution_level.append(distribution_level_adv_values)

                loss_type = self.config.model.loss_type
                if loss_type == "ce":
                    dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta, loss_weight=loss_weight)
                elif loss_type == "dpo":
                    # the implementation of dpo is actually detached, which means we have to know the average
                    # value of w/l reward before the update.
                    dpo_loss = compute_detach_dpo_loss_rm(
                        q, acc, Q_bc=micro_batch["Q_bc"], acc_bc=micro_batch["acc_bc"], response_mask=response_mask, beta=beta, loss_weight=loss_weight
                    )
                elif loss_type == "drm":
                    dpo_loss = implicit_drm_loss(
                        q,
                        acc,
                        response_mask=response_mask,
                        beta=beta,
                        loss_weight=loss_weight,
                        gamma=self.config.model.drm_gamma,
                    )
                else:
                    # The "bon_acc" / "bon_rm" variants are intentionally not supported here.
                    raise NotImplementedError

                if self.config.use_dynamic_bsz:
                    # Weight each micro-batch by its share of the mini-batch. The size is
                    # taken from the micro-batch tensors themselves.
                    # NOTE(review): the mini-batches are split by `config.mini_batch_size`
                    # but the denominator reads `config.ppo_mini_batch_size` -- confirm
                    # both keys exist and agree.
                    micro_bsz = attention_mask.shape[0]
                    loss = dpo_loss * (micro_bsz / self.config.ppo_mini_batch_size)
                else:
                    loss = dpo_loss / self.gradient_accumulation
                loss.backward()

                append_to_dict(metrics, {"reward_model/dpo_loss": dpo_loss.detach().item()})

            grad_norm = self._optimizer_step()
            append_to_dict(metrics, {"reward_model/grad_norm": grad_norm.detach().item()})

            mini_rm_scores = torch.cat(mini_rm_scores, dim=0)
            mini_q = torch.cat(mini_q, dim=0)
            mini_distribution_level = torch.cat(mini_distribution_level, dim=0)
            if micro_indices is not None:
                # Undo the re-ordering introduced by dynamic batching for this mini-batch.
                flat_indices = list(itertools.chain.from_iterable(micro_indices))
                assert len(flat_indices) == mini_rm_scores.size(0), f"{len(flat_indices)} vs. {mini_rm_scores.size()}"
                revert_indices = torch.tensor(
                    get_reverse_idx(flat_indices), dtype=torch.long, device=mini_rm_scores.device
                )
                mini_rm_scores = mini_rm_scores[revert_indices]
                mini_q = mini_q[revert_indices]
                mini_distribution_level = mini_distribution_level[revert_indices]
            rm_scores_lst.append(mini_rm_scores)
            q_lst.append(mini_q)
            distribution_level_adv_values_lst.append(mini_distribution_level)
        self.reward_optimizer.zero_grad()

        rm_scores = torch.cat(rm_scores_lst, dim=0)
        q = torch.concat(q_lst, dim=0)
        distribution_level_adv_values = torch.concat(distribution_level_adv_values_lst, dim=0)

        response_mask = batch["response_mask"]
        # rm_scores = self.prime_norm(rm_scores, response_mask)
        # distribution_level_adv_values = self.dl_norm(distribution_level_adv_values)
        metrics.update(
            {
                "reward_model/reward": rm_scores[response_mask.bool()].mean().item(),
                "reward_model/raw_reward": q[response_mask.bool()].mean().item(),
            },
        )

        return rm_scores, distribution_level_adv_values, metrics


def suffix_sum_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Return the masked suffix sum of ``score`` along dim 1.

    Element ``[b, t]`` of the result is the sum of ``score[b, t:] * eos_mask[b, t:]``.

    Args:
        score: Values to accumulate, ``[B, L]``.
        eos_mask: Same shape as ``score``; 1 marks a valid position, 0 marks padding.
    """
    assert score.shape == eos_mask.shape

    # Padding contributes nothing to any suffix.
    masked = score * eos_mask
    # A suffix sum is a prefix sum taken over the reversed sequence.
    reversed_running_total = torch.cumsum(torch.flip(masked, dims=[1]), dim=1)
    return torch.flip(reversed_running_total, dims=[1])  # [B, L]


def suffix_mean_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Return the masked suffix mean of ``score`` along dim 1.

    Element ``[b, t]`` is the mean of the valid entries of ``score[b, t:]``; padded
    positions of the result are set to 0.

    Args:
        score: Values to average, ``[batch_size, seq_len]``.
        eos_mask: Same shape as ``score``; 1 for valid tokens, 0 for padding.
    """
    assert score.shape == eos_mask.shape

    # Work on the reversed sequence so that a running sum yields suffix totals.
    rev_mask = torch.flip(eos_mask, dims=[1])
    rev_masked_score = torch.flip(score, dims=[1]) * rev_mask

    # Suffix total of the valid scores and the number of valid tokens in each suffix.
    suffix_total = torch.flip(torch.cumsum(rev_masked_score, dim=1), dims=[1])  # [B, L]
    valid_count = torch.flip(torch.cumsum(rev_mask, dim=1), dims=[1])  # [B, L]

    # Avoid division by zero; the numerator is 0 wherever the count was 0.
    valid_count = valid_count.clamp(min=1)

    # Reset padded positions to 0.
    return (suffix_total / valid_count) * eos_mask


def suffix_geomean_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Return the masked suffix geometric mean of ``score`` along dim 1.

    Element ``[b, t]`` is the geometric mean of the valid entries of
    ``score[b, t:]``; padded positions of the result are set to 0.

    Only valid positions need to be positive. Padded entries may hold any value
    (zero padding is the common case): they are replaced by 1 before the
    logarithm so they can neither trip the check nor inject ``-inf``/NaN.

    Args:
        score: Values to average, ``[batch_size, seq_len]``; must be > 0 wherever
            ``eos_mask`` is non-zero.
        eos_mask: Same shape as ``score``; 1 for valid tokens, 0 for padding.

    Raises:
        AssertionError: If the shapes differ or a valid score is not positive.
    """
    assert score.shape == eos_mask.shape
    valid = eos_mask != 0
    assert (score[valid] > 0).all(), "Score must be positive for geometric mean"

    # log(1) == 0, so neutralised padding contributes nothing to the log-sum.
    safe_score = torch.where(valid, score, torch.ones_like(score))

    # Flip for suffix computation
    score_flipped = safe_score.flip(dims=[1])
    mask_flipped = eos_mask.flip(dims=[1])

    # log domain for numeric stability
    log_score = torch.log(score_flipped) * mask_flipped  # invalid tokens -> 0
    log_cumsum = torch.cumsum(log_score, dim=1).flip(dims=[1])  # [B, L]

    count = torch.cumsum(mask_flipped, dim=1).flip(dims=[1])  # valid token count
    count = torch.clamp(count, min=1)  # avoid division by zero

    # geometric mean in log space: exp(sum log x / n)
    suffix_geomean = torch.exp(log_cumsum / count)

    # mask out padding tokens
    return suffix_geomean * eos_mask