import torch
from trl import GRPOTrainer
from ..utils.trainer_utils import *

class FKLTrainer(GRPOTrainer):
    """GRPO trainer variant whose KL penalty is built from ``new - ref`` log-probs.

    The per-token penalty is ``exp(new - ref) - (new - ref) - 1`` where ``new``
    and ``ref`` are the log-probabilities of the sampled tokens under the
    current and the reference model respectively.
    """

    @staticmethod
    def grpo_compute_loss_slow(
        ref_logits,
        new_logits,
        old_logits,
        input_ids,
        mask,
        beta,
        advantages,
        **kwargs
    ):
        """Compute the clipped GRPO surrogate loss plus a ``beta``-weighted KL term.

        Declared as a ``staticmethod`` because the signature has no ``self``:
        as a plain method, an instance call would bind the trainer to
        ``ref_logits`` and shift every other argument by one.

        Args:
            ref_logits: Logits of the reference model; only read when
                ``beta != 0`` (must not be ``None`` in that case).
            new_logits: Logits of the model being trained (gradients flow
                through these). Assumed ``(batch, seq, vocab)`` - TODO confirm.
            old_logits: Logits of the sampling policy, or ``None`` to use
                ``new.detach()`` as the denominator of the importance ratio.
            input_ids: Token ids whose log-probabilities are gathered from the
                last dimension of each logits tensor.
            mask: Per-token weights; positions with 0 are excluded from the
                reductions.
            beta: Weight of the KL penalty; ``0`` disables it.
            advantages: One advantage per sequence (broadcast over tokens via
                ``unsqueeze(1)``).
            **kwargs: Optional settings: ``loss_type`` ("grpo", "bnpo",
                "dr_grpo"), ``epsilon_low``, ``epsilon_high``,
                ``max_completion_length``, ``delta``, ``temperature``,
                ``logit_scale_multiply``, ``logit_scale_divide``,
                ``logit_softcapping``, ``importance_sampling_level``
                ("token" or "sequence").

        Returns:
            Tuple ``(loss, completion_length, mean_kl)``: the scalar loss, the
            mean number of unmasked tokens per sequence, and the masked mean
            of the KL term. NOTE(review): this triple mirrors Unsloth Zoo's
            ``grpo_compute_loss``; previously nothing was returned.

        Raises:
            ValueError: Unknown ``importance_sampling_level`` or ``loss_type``.
            AssertionError: ``beta != 0`` while ``ref_logits`` is ``None``.
        """
        # All Unsloth Zoo code licensed under LGPLv3
        loss_type = kwargs.get("loss_type", "grpo")
        epsilon_low = kwargs.get("epsilon_low", 0.2)
        epsilon_high = kwargs.get("epsilon_high", 0.2)
        max_completion_length = kwargs.get("max_completion_length", 8192)
        delta = kwargs.get("delta", None)
        temperature = kwargs.get("temperature", 1.0)
        logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
        logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
        logit_softcapping = kwargs.get("logit_softcapping", 0.0)
        importance_sampling_level = kwargs.get("importance_sampling_level", "token")

        # gather() needs an index with the same rank as the logits.
        input_ids = input_ids.unsqueeze(-1)

        def _selected_logps(logits):
            # Log-probability of each id in `input_ids`: x_i - logsumexp(x).
            if logit_scale_multiply != 0:
                logits = logits * logit_scale_multiply
            if logit_scale_divide != 0:
                logits = logits / logit_scale_divide
            if logit_softcapping != 0:
                # NOTE(review): the usual soft-cap is `cap * tanh(logits / cap)`;
                # this multiplies by the logits instead. Kept as-is so the
                # numerics do not change - confirm intent.
                logits = logits * torch.tanh(logits / logit_softcapping)
            logits = logits.to(torch.float32)
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            if temperature != 1.0:
                logits = logits / temperature
            picked = torch.gather(logits, dim=-1, index=input_ids).squeeze(-1)
            return picked - torch.logsumexp(logits, dim=-1)

        new = _selected_logps(new_logits)

        # Reference and old-policy log-probs are constants for the optimizer.
        ref = None
        old = None
        with torch.no_grad():
            if beta != 0.0:
                assert ref_logits is not None, "ref_logits should not be None when beta != 0.0"
                ref = _selected_logps(ref_logits)
            if old_logits is not None:
                old = _selected_logps(old_logits)

        # KL penalty of this trainer (labelled "Forward KL" by the author).
        if beta != 0.0:
            kl_i = torch.exp(new - ref) - (new - ref) - 1.0
        elif importance_sampling_level == "sequence":
            # Zeros shaped like the per-sequence loss.
            kl_i = new.new_zeros(new.size(0), 1)
        else:
            kl_i = torch.zeros_like(new)

        # Without old logits the ratio is exp(new - new.detach()) == 1 in value
        # but still carries the gradient of `new`.
        if old is not None:
            log_ratio = new - old
        else:
            log_ratio = new - new.detach()

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

        seq_advantages = advantages.unsqueeze(1)
        if delta is not None:
            # Optional upper bound on the unclipped ratio.
            loss_1 = torch.clamp(coef_1, max=delta) * seq_advantages
        else:
            loss_1 = coef_1 * seq_advantages
        loss_2 = coef_2 * seq_advantages
        loss_i = -torch.min(loss_1, loss_2)
        if beta != 0.0:
            loss_i = loss_i + beta * kl_i

        mask = mask.to(torch.float32)
        n_mask_per_reward = mask.sum(1)

        # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
        if loss_type == "grpo":
            loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
        elif loss_type == "bnpo":
            loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
        elif loss_type == "dr_grpo":
            loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Logging metrics only; no gradient needed.
        with torch.no_grad():
            completion_length = n_mask_per_reward.mean()
            if kl_i.shape[1] == 1:
                # Per-sequence placeholder: nothing to mask.
                mean_kl = kl_i.mean()
            else:
                mean_kl = ((kl_i * mask).sum(1) / n_mask_per_reward.clamp(min=1.0)).mean()

        return loss, completion_length, mean_kl

class BranchTrainer(GRPOTrainer):
    """GRPO trainer whose KL penalty shifts the reference log-probs by ``log q``.

    ``q`` is obtained from ``compressor.sample(new=..., ref=..., mask=...)``
    on every loss evaluation with ``beta != 0``.
    """

    def __init__(self, compressor: EntropyCompressor, *args, **kwargs):
        """Store ``compressor`` and forward all other arguments to ``GRPOTrainer``."""
        super().__init__(*args, **kwargs)
        self.compressor = compressor

    def grpo_compute_loss_slow(
        self,
        ref_logits,
        new_logits,
        old_logits,
        input_ids,
        mask,
        beta,
        advantages,
        **kwargs
    ):
        """Compute the clipped GRPO surrogate loss plus a ``beta``-weighted KL term.

        The KL term is ``exp(d) - d - 1`` with ``d = ref + log(q) - new``,
        where ``q`` comes from ``self.compressor.sample``.

        Args:
            ref_logits: Logits of the reference model; only read when
                ``beta != 0`` (must not be ``None`` in that case).
            new_logits: Logits of the model being trained (gradients flow
                through these). Assumed ``(batch, seq, vocab)`` - TODO confirm.
            old_logits: Logits of the sampling policy, or ``None`` to use
                ``new.detach()`` as the denominator of the importance ratio.
            input_ids: Token ids whose log-probabilities are gathered from the
                last dimension of each logits tensor.
            mask: Per-token weights; positions with 0 are excluded from the
                reductions.
            beta: Weight of the KL penalty; ``0`` disables it.
            advantages: One advantage per sequence (broadcast over tokens via
                ``unsqueeze(1)``).
            **kwargs: Optional settings: ``loss_type`` ("grpo", "bnpo",
                "dr_grpo"), ``epsilon_low``, ``epsilon_high``,
                ``max_completion_length``, ``delta``, ``temperature``,
                ``logit_scale_multiply``, ``logit_scale_divide``,
                ``logit_softcapping``, ``importance_sampling_level``
                ("token" or "sequence").

        Returns:
            Tuple ``(loss, completion_length, mean_kl)``: the scalar loss, the
            mean number of unmasked tokens per sequence, and the masked mean
            of the KL term. NOTE(review): this triple mirrors Unsloth Zoo's
            ``grpo_compute_loss``; previously nothing was returned.

        Raises:
            ValueError: Unknown ``importance_sampling_level`` or ``loss_type``.
            AssertionError: ``beta != 0`` while ``ref_logits`` is ``None``.
        """
        # All Unsloth Zoo code licensed under LGPLv3
        loss_type = kwargs.get("loss_type", "grpo")
        epsilon_low = kwargs.get("epsilon_low", 0.2)
        epsilon_high = kwargs.get("epsilon_high", 0.2)
        max_completion_length = kwargs.get("max_completion_length", 8192)
        delta = kwargs.get("delta", None)
        temperature = kwargs.get("temperature", 1.0)
        logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
        logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
        logit_softcapping = kwargs.get("logit_softcapping", 0.0)
        importance_sampling_level = kwargs.get("importance_sampling_level", "token")

        # gather() needs an index with the same rank as the logits.
        input_ids = input_ids.unsqueeze(-1)

        def _selected_logps(logits):
            # Log-probability of each id in `input_ids`: x_i - logsumexp(x).
            if logit_scale_multiply != 0:
                logits = logits * logit_scale_multiply
            if logit_scale_divide != 0:
                logits = logits / logit_scale_divide
            if logit_softcapping != 0:
                # NOTE(review): the usual soft-cap is `cap * tanh(logits / cap)`;
                # this multiplies by the logits instead. Kept as-is so the
                # numerics do not change - confirm intent.
                logits = logits * torch.tanh(logits / logit_softcapping)
            logits = logits.to(torch.float32)
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            if temperature != 1.0:
                logits = logits / temperature
            picked = torch.gather(logits, dim=-1, index=input_ids).squeeze(-1)
            return picked - torch.logsumexp(logits, dim=-1)

        new = _selected_logps(new_logits)

        # Reference and old-policy log-probs are constants for the optimizer.
        ref = None
        old = None
        with torch.no_grad():
            if beta != 0.0:
                assert ref_logits is not None, "ref_logits should not be None when beta != 0.0"
                ref = _selected_logps(ref_logits)
            if old_logits is not None:
                old = _selected_logps(old_logits)

        # KL penalty against the q-shifted reference (labelled "Reverse KL" by the author).
        if beta != 0.0:
            q = self.compressor.sample(new=new, ref=ref, mask=mask)
            # Clamp for safety even if the compressor guarantees q_min > 0.
            log_q = torch.log(q.clamp_min(1e-8))
            ref_scaled = ref + log_q
            kl_i = torch.exp(ref_scaled - new) - (ref_scaled - new) - 1.0
        elif importance_sampling_level == "sequence":
            # Zeros shaped like the per-sequence loss.
            kl_i = new.new_zeros(new.size(0), 1)
        else:
            kl_i = torch.zeros_like(new)

        # Without old logits the ratio is exp(new - new.detach()) == 1 in value
        # but still carries the gradient of `new`.
        if old is not None:
            log_ratio = new - old
        else:
            log_ratio = new - new.detach()

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

        seq_advantages = advantages.unsqueeze(1)
        if delta is not None:
            # Optional upper bound on the unclipped ratio.
            loss_1 = torch.clamp(coef_1, max=delta) * seq_advantages
        else:
            loss_1 = coef_1 * seq_advantages
        loss_2 = coef_2 * seq_advantages
        loss_i = -torch.min(loss_1, loss_2)
        if beta != 0.0:
            loss_i = loss_i + beta * kl_i

        mask = mask.to(torch.float32)
        n_mask_per_reward = mask.sum(1)

        # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
        if loss_type == "grpo":
            loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
        elif loss_type == "bnpo":
            loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
        elif loss_type == "dr_grpo":
            loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Logging metrics only; no gradient needed.
        with torch.no_grad():
            completion_length = n_mask_per_reward.mean()
            if kl_i.shape[1] == 1:
                # Per-sequence placeholder: nothing to mask.
                mean_kl = kl_i.mean()
            else:
                mean_kl = ((kl_i * mask).sum(1) / n_mask_per_reward.clamp(min=1.0)).mean()

        return loss, completion_length, mean_kl



class RandomTrainer(GRPOTrainer):
    """GRPO trainer whose KL penalty shifts the reference log-probs by ``log q``.

    ``q`` is obtained from the ``sampler`` passed to the constructor on every
    loss evaluation with ``beta != 0``.
    """

    def __init__(self, sampler: Sampler, *args, **kwargs):
        """Store ``sampler`` and forward all other arguments to ``GRPOTrainer``."""
        super().__init__(*args, **kwargs)
        self.sampler = sampler

    def grpo_compute_loss_slow(
        self,
        ref_logits,
        new_logits,
        old_logits,
        input_ids,
        mask,
        beta,
        advantages,
        **kwargs
    ):
        """Compute the clipped GRPO surrogate loss plus a ``beta``-weighted KL term.

        The KL term is ``exp(d) - d - 1`` with ``d = ref + log(q) - new``,
        where ``q`` comes from ``self.sampler.sample``.

        Args:
            ref_logits: Logits of the reference model; only read when
                ``beta != 0`` (must not be ``None`` in that case).
            new_logits: Logits of the model being trained (gradients flow
                through these). Assumed ``(batch, seq, vocab)`` - TODO confirm.
            old_logits: Logits of the sampling policy, or ``None`` to use
                ``new.detach()`` as the denominator of the importance ratio.
            input_ids: Token ids whose log-probabilities are gathered from the
                last dimension of each logits tensor.
            mask: Per-token weights; positions with 0 are excluded from the
                reductions.
            beta: Weight of the KL penalty; ``0`` disables it.
            advantages: One advantage per sequence (broadcast over tokens via
                ``unsqueeze(1)``).
            **kwargs: Optional settings: ``loss_type`` ("grpo", "bnpo",
                "dr_grpo"), ``epsilon_low``, ``epsilon_high``,
                ``max_completion_length``, ``delta``, ``temperature``,
                ``logit_scale_multiply``, ``logit_scale_divide``,
                ``logit_softcapping``, ``importance_sampling_level``
                ("token" or "sequence").

        Returns:
            Tuple ``(loss, completion_length, mean_kl)``: the scalar loss, the
            mean number of unmasked tokens per sequence, and the masked mean
            of the KL term. NOTE(review): this triple mirrors Unsloth Zoo's
            ``grpo_compute_loss``; previously nothing was returned.

        Raises:
            ValueError: Unknown ``importance_sampling_level`` or ``loss_type``.
            AssertionError: ``beta != 0`` while ``ref_logits`` is ``None``.
        """
        # All Unsloth Zoo code licensed under LGPLv3
        loss_type = kwargs.get("loss_type", "grpo")
        epsilon_low = kwargs.get("epsilon_low", 0.2)
        epsilon_high = kwargs.get("epsilon_high", 0.2)
        max_completion_length = kwargs.get("max_completion_length", 8192)
        delta = kwargs.get("delta", None)
        temperature = kwargs.get("temperature", 1.0)
        logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
        logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
        logit_softcapping = kwargs.get("logit_softcapping", 0.0)
        importance_sampling_level = kwargs.get("importance_sampling_level", "token")

        # gather() needs an index with the same rank as the logits.
        input_ids = input_ids.unsqueeze(-1)

        def _selected_logps(logits):
            # Log-probability of each id in `input_ids`: x_i - logsumexp(x).
            if logit_scale_multiply != 0:
                logits = logits * logit_scale_multiply
            if logit_scale_divide != 0:
                logits = logits / logit_scale_divide
            if logit_softcapping != 0:
                # NOTE(review): the usual soft-cap is `cap * tanh(logits / cap)`;
                # this multiplies by the logits instead. Kept as-is so the
                # numerics do not change - confirm intent.
                logits = logits * torch.tanh(logits / logit_softcapping)
            logits = logits.to(torch.float32)
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            if temperature != 1.0:
                logits = logits / temperature
            picked = torch.gather(logits, dim=-1, index=input_ids).squeeze(-1)
            return picked - torch.logsumexp(logits, dim=-1)

        new = _selected_logps(new_logits)

        # Reference and old-policy log-probs are constants for the optimizer.
        ref = None
        old = None
        with torch.no_grad():
            if beta != 0.0:
                assert ref_logits is not None, "ref_logits should not be None when beta != 0.0"
                ref = _selected_logps(ref_logits)
            if old_logits is not None:
                old = _selected_logps(old_logits)

        # KL penalty against the q-shifted reference (labelled "Reverse KL" by the author).
        if beta != 0.0:
            # `self.sampler` is the only helper __init__ stores on this class;
            # reading `self.compressor` here would raise AttributeError.
            # NOTE(review): assumes the sampler exposes
            # `sample(new=, ref=, mask=)` - confirm against its definition.
            q = self.sampler.sample(new=new, ref=ref, mask=mask)
            # Clamp for safety even if the sampler guarantees q_min > 0.
            log_q = torch.log(q.clamp_min(1e-8))
            ref_scaled = ref + log_q
            kl_i = torch.exp(ref_scaled - new) - (ref_scaled - new) - 1.0
        elif importance_sampling_level == "sequence":
            # Zeros shaped like the per-sequence loss.
            kl_i = new.new_zeros(new.size(0), 1)
        else:
            kl_i = torch.zeros_like(new)

        # Without old logits the ratio is exp(new - new.detach()) == 1 in value
        # but still carries the gradient of `new`.
        if old is not None:
            log_ratio = new - old
        else:
            log_ratio = new - new.detach()

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

        seq_advantages = advantages.unsqueeze(1)
        if delta is not None:
            # Optional upper bound on the unclipped ratio.
            loss_1 = torch.clamp(coef_1, max=delta) * seq_advantages
        else:
            loss_1 = coef_1 * seq_advantages
        loss_2 = coef_2 * seq_advantages
        loss_i = -torch.min(loss_1, loss_2)
        if beta != 0.0:
            loss_i = loss_i + beta * kl_i

        mask = mask.to(torch.float32)
        n_mask_per_reward = mask.sum(1)

        # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
        if loss_type == "grpo":
            loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
        elif loss_type == "bnpo":
            loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
        elif loss_type == "dr_grpo":
            loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Logging metrics only; no gradient needed.
        with torch.no_grad():
            completion_length = n_mask_per_reward.mean()
            if kl_i.shape[1] == 1:
                # Per-sequence placeholder: nothing to mask.
                mean_kl = kl_i.mean()
            else:
                mean_kl = ((kl_i * mask).sum(1) / n_mask_per_reward.clamp(min=1.0)).mean()

        return loss, completion_length, mean_kl

class TokenTrainer(GRPOTrainer):
    """GRPO trainer whose KL penalty shifts the reference log-probs by ``log q``.

    ``q`` is obtained from the ``SurprisalAwareSampler`` passed to the
    constructor on every loss evaluation with ``beta != 0``.
    """

    def __init__(self, sampler: SurprisalAwareSampler, *args, **kwargs):
        """Store ``sampler`` and forward all other arguments to ``GRPOTrainer``."""
        super().__init__(*args, **kwargs)
        self.sampler = sampler

    def grpo_compute_loss_slow(
        self,
        ref_logits,
        new_logits,
        old_logits,
        input_ids,
        mask,
        beta,
        advantages,
        **kwargs
    ):
        """Compute the clipped GRPO surrogate loss plus a ``beta``-weighted KL term.

        The KL term is ``exp(d) - d - 1`` with ``d = ref + log(q) - new``,
        where ``q`` comes from ``self.sampler.sample``.

        Args:
            ref_logits: Logits of the reference model; only read when
                ``beta != 0`` (must not be ``None`` in that case).
            new_logits: Logits of the model being trained (gradients flow
                through these). Assumed ``(batch, seq, vocab)`` - TODO confirm.
            old_logits: Logits of the sampling policy, or ``None`` to use
                ``new.detach()`` as the denominator of the importance ratio.
            input_ids: Token ids whose log-probabilities are gathered from the
                last dimension of each logits tensor.
            mask: Per-token weights; positions with 0 are excluded from the
                reductions.
            beta: Weight of the KL penalty; ``0`` disables it.
            advantages: One advantage per sequence (broadcast over tokens via
                ``unsqueeze(1)``).
            **kwargs: Optional settings: ``loss_type`` ("grpo", "bnpo",
                "dr_grpo"), ``epsilon_low``, ``epsilon_high``,
                ``max_completion_length``, ``delta``, ``temperature``,
                ``logit_scale_multiply``, ``logit_scale_divide``,
                ``logit_softcapping``, ``importance_sampling_level``
                ("token" or "sequence").

        Returns:
            Tuple ``(loss, completion_length, mean_kl)``: the scalar loss, the
            mean number of unmasked tokens per sequence, and the masked mean
            of the KL term. NOTE(review): this triple mirrors Unsloth Zoo's
            ``grpo_compute_loss``; previously nothing was returned.

        Raises:
            ValueError: Unknown ``importance_sampling_level`` or ``loss_type``.
            AssertionError: ``beta != 0`` while ``ref_logits`` is ``None``.
        """
        # All Unsloth Zoo code licensed under LGPLv3
        loss_type = kwargs.get("loss_type", "grpo")
        epsilon_low = kwargs.get("epsilon_low", 0.2)
        epsilon_high = kwargs.get("epsilon_high", 0.2)
        max_completion_length = kwargs.get("max_completion_length", 8192)
        delta = kwargs.get("delta", None)
        temperature = kwargs.get("temperature", 1.0)
        logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
        logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
        logit_softcapping = kwargs.get("logit_softcapping", 0.0)
        importance_sampling_level = kwargs.get("importance_sampling_level", "token")

        # gather() needs an index with the same rank as the logits.
        input_ids = input_ids.unsqueeze(-1)

        def _selected_logps(logits):
            # Log-probability of each id in `input_ids`: x_i - logsumexp(x).
            if logit_scale_multiply != 0:
                logits = logits * logit_scale_multiply
            if logit_scale_divide != 0:
                logits = logits / logit_scale_divide
            if logit_softcapping != 0:
                # NOTE(review): the usual soft-cap is `cap * tanh(logits / cap)`;
                # this multiplies by the logits instead. Kept as-is so the
                # numerics do not change - confirm intent.
                logits = logits * torch.tanh(logits / logit_softcapping)
            logits = logits.to(torch.float32)
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            if temperature != 1.0:
                logits = logits / temperature
            picked = torch.gather(logits, dim=-1, index=input_ids).squeeze(-1)
            return picked - torch.logsumexp(logits, dim=-1)

        new = _selected_logps(new_logits)

        # Reference and old-policy log-probs are constants for the optimizer.
        ref = None
        old = None
        with torch.no_grad():
            if beta != 0.0:
                assert ref_logits is not None, "ref_logits should not be None when beta != 0.0"
                ref = _selected_logps(ref_logits)
            if old_logits is not None:
                old = _selected_logps(old_logits)

        # KL penalty against the q-shifted reference (labelled "Reverse KL" by the author).
        if beta != 0.0:
            # `self.sampler` is the only helper __init__ stores on this class;
            # reading `self.compressor` here would raise AttributeError.
            # NOTE(review): assumes the sampler exposes
            # `sample(new=, ref=, mask=)` - confirm against its definition.
            q = self.sampler.sample(new=new, ref=ref, mask=mask)
            # Clamp for safety even if the sampler guarantees q_min > 0.
            log_q = torch.log(q.clamp_min(1e-8))
            ref_scaled = ref + log_q
            kl_i = torch.exp(ref_scaled - new) - (ref_scaled - new) - 1.0
        elif importance_sampling_level == "sequence":
            # Zeros shaped like the per-sequence loss.
            kl_i = new.new_zeros(new.size(0), 1)
        else:
            kl_i = torch.zeros_like(new)

        # Without old logits the ratio is exp(new - new.detach()) == 1 in value
        # but still carries the gradient of `new`.
        if old is not None:
            log_ratio = new - old
        else:
            log_ratio = new - new.detach()

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

        seq_advantages = advantages.unsqueeze(1)
        if delta is not None:
            # Optional upper bound on the unclipped ratio.
            loss_1 = torch.clamp(coef_1, max=delta) * seq_advantages
        else:
            loss_1 = coef_1 * seq_advantages
        loss_2 = coef_2 * seq_advantages
        loss_i = -torch.min(loss_1, loss_2)
        if beta != 0.0:
            loss_i = loss_i + beta * kl_i

        mask = mask.to(torch.float32)
        n_mask_per_reward = mask.sum(1)

        # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
        if loss_type == "grpo":
            loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
        elif loss_type == "bnpo":
            loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
        elif loss_type == "dr_grpo":
            loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Logging metrics only; no gradient needed.
        with torch.no_grad():
            completion_length = n_mask_per_reward.mean()
            if kl_i.shape[1] == 1:
                # Per-sequence placeholder: nothing to mask.
                mean_kl = kl_i.mean()
            else:
                mean_kl = ((kl_i * mask).sum(1) / n_mask_per_reward.clamp(min=1.0)).mean()

        return loss, completion_length, mean_kl
