# part of the code is adapted from https://github.com/huggingface/trl
from typing import Any, Union, Optional

import torch
import wandb
from accelerate.utils import broadcast_object_list, gather, gather_object
from scipy.special import binom
from torch import nn
from transformers import Trainer
from transformers.utils import is_peft_available
from trl.data_utils import (
    apply_chat_template,
    is_conversational,
    maybe_apply_chat_template,
)
from trl.models import unwrap_model_for_generation
from trl.trainer.utils import pad

from trainer.custom_grpo import CustomGRPOTrainer
from trainer.custom_grpo import nanstd, nanmax, nanmin

if is_peft_available():
    pass

from trl.extras.profiling import profiling_context

# todo add is vllm available check
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from contextlib import nullcontext

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import warnings


def split_tensor_dict(
        tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int
) -> list[dict[str, Optional[torch.Tensor]]]:
    """
    Cut every tensor of ``tensor_dict`` along dim 0 into ``num_chunks`` equally sized pieces.

    The piece size is ``len(first non-None tensor) // num_chunks``; rows that do not fill a
    complete piece at the end are dropped. ``None`` entries stay ``None`` in every piece.

    Example:
        >>> x = torch.arange(12).reshape(6, 2)
        >>> y = torch.arange(6).reshape(6, 1)
        >>> tensor_dict = {"x": x, "y": y}
        >>> split_tensor_dict(tensor_dict, 3)
        [
            {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])},
            {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])},
            {"x": tensor([[ 8,  9], [10, 11]]), "y": tensor([[4], [5]])}
        ]
    """
    reference = next(value for value in tensor_dict.values() if value is not None)
    rows_per_chunk = reference.shape[0] // num_chunks

    pieces = []
    for chunk_idx in range(num_chunks):
        start = chunk_idx * rows_per_chunk
        stop = start + rows_per_chunk
        pieces.append(
            {
                name: None if value is None else value[start:stop]
                for name, value in tensor_dict.items()
            }
        )
    return pieces


class BonGRPOTrainer(CustomGRPOTrainer):
    def __init__(self, best_k=8, var_redaction=None, clamp_delta=0.2, *args, **kwargs):
        """
        Best-of-k GRPO trainer.

        Args:
            best_k: subset size ``k`` of the best-of-k objective used to re-weight rewards.
            var_redaction: requested variance-reduction baseline (``recalculate_rewards``
                recognises ``"loo"`` and ``"loo-1"``). It is currently NOT applied: the
                attribute is always stored as ``None``, and a warning is emitted when a value
                is passed so the argument is no longer discarded silently.
            clamp_delta: symmetric bound applied to the off-policy ``deltas`` in
                ``bon_scaler_offpolicy``.
            *args, **kwargs: forwarded unchanged to the parent trainer's ``__init__``.
        """
        super().__init__(*args, **kwargs)
        self.best_k = best_k
        if var_redaction is not None:
            # loo_redaction reads entries above the diagonal of bon_scaler's lower-triangular
            # scale matrix (always zero) and indexes `scale` as 2-D, which the (n, m, m)
            # off-policy weights are not, so it stays disabled; tell the caller explicitly.
            warnings.warn(
                f"BonGRPOTrainer: var_redaction={var_redaction!r} is currently disabled and "
                "will be ignored.",
                stacklevel=2,
            )
        # Variance reduction stays disabled (see the warning above).
        self.var_redaction = None
        self.clamp_delta = clamp_delta

    def _prepare_inputs(
            self, generation_batch: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Return the local batch for the current step.

        Training: the incoming ``generation_batch`` covers ``steps_per_generation`` local
        batches. Completions are generated and scored for the whole of it whenever
        ``self._step`` is a multiple of ``steps_per_generation * num_iterations`` (or when the
        buffer is empty, e.g. after resuming from a checkpoint). The scored result is split into
        ``steps_per_generation`` slices and buffered, and the slice for the current step is
        returned. ``self._step`` is then advanced.

        Evaluation: the input is an ordinary local batch; completions are generated and scored
        for it directly, without any buffering or reuse.
        """
        if not self.model.training:
            # No generation grouping and no repeated iterations during evaluation.
            return self._generate_and_score_completions(generation_batch)

        slices_per_generation = self.args.steps_per_generation
        regeneration_period = slices_per_generation * self.num_iterations
        must_regenerate = (
            self._step % regeneration_period == 0 or self._buffered_inputs is None
        )
        if must_regenerate:
            scored_batch = self._generate_and_score_completions(generation_batch)
            self._buffered_inputs = split_tensor_dict(
                scored_batch, slices_per_generation
            )

        current_slice = self._buffered_inputs[self._step % slices_per_generation]
        self._step += 1
        return current_slice

    def bon_scaler(self, rewards, k):
        n, m = rewards.shape  # n prompts, m generations

        # Calculate scale factors for non diagonal elements
        den = binom(m, self.best_k)
        scale = binom(torch.arange(1, m + 1) - 2, self.best_k - 2) / den
        scale = scale.nan_to_num(0)

        # broadcast scales for each element. here each column is scale
        scale = scale.repeat(m, 1).T
        scale = torch.tril(scale, diagonal=-1)

        # add diagonal elements to scales
        diag_coef = binom(torch.arange(1, m + 1) - 1, self.best_k - 1) / den
        diag_coef = diag_coef.nan_to_num(0)
        diag_coef = torch.diag(diag_coef)

        scale = scale + diag_coef

        # convert to same dtype as rewards
        scale = scale.to(rewards.dtype)
        scale = scale.to(rewards.device)

        # calculate bon rewards
        bon_rewards = rewards @ scale
        # bon_rewards = bon_rewards.squeeze()
        return bon_rewards, scale

    def bon_scaler_offpolicy(self, rewards, deltas, k):
        n, m = rewards.shape  # n prompts, m generations
        den = binom(m, k)  # denominator aka number of k-tuples from generations

        # print("deltas.max()=", deltas.abs().max())
        deltas = deltas.clamp(min=-self.clamp_delta, max=self.clamp_delta)
        cum_deltas = (
                torch.cumsum(deltas, dim=1) - deltas
        )  # sum of all deltas for correction

        # calculating C(j-2, k-2) scale for off-diagonal elements
        scale_1 = binom(torch.arange(1, m + 1) - 2, k - 2) / den
        scale_1 = scale_1.nan_to_num(0)

        # Broadcast scales for each element
        scale_1 = scale_1.repeat(m, 1).T
        scale_1 = torch.tril(scale_1, diagonal=-1)
        scale_1 = scale_1.to(rewards.device)

        # Add diagonal elements to scales
        diag_scale_1 = binom(torch.arange(1, m + 1) - 1, k - 1) / den  # C(j-1, k-1)
        diag_scale_1 = diag_scale_1.nan_to_num(0)
        diag_scale_1 = diag_scale_1.to(rewards.device)

        scale_base = scale_1 + diag_scale_1

        # off-policy correction

        # C(j-2, k-2) * (delta_i + delta_j)
        off_diag_term1 = scale_1 * (deltas.view(n, m, 1) + deltas.view(n, 1, m))

        # calculating C(j-3, k-3) scale for off-diagonal elements
        scale_2 = binom(torch.arange(1, m + 1) - 3, k - 3) / den
        scale_2 = scale_2.nan_to_num(0)
        scale_2 = scale_2.repeat(m, 1).T
        scale_2 = torch.tril(scale_2, diagonal=-1)
        scale_2 = scale_2.to(rewards.device)

        # C(j-3, k-3) * (cum_delta_j - delta_i)
        off_diag_term2 = scale_2 * (cum_deltas.view(n, m, 1) - deltas.view(n, 1, m))

        off_diag = off_diag_term1 + off_diag_term2

        diag_scale_2 = binom(torch.arange(1, m + 1) - 2, k - 2) / den  # C(j-2, k-2)
        diag_scale_2 = diag_scale_2.nan_to_num(0)

        diag_scale_1 = diag_scale_1.to(rewards.device)
        diag_scale_2 = diag_scale_2.to(rewards.device)

        # C(i-1, k-1) * (1 + delta_i) + C(i-2, k-2) * cum_delta_i
        diag_term = diag_scale_1 * deltas + diag_scale_2 * cum_deltas
        diag_term = torch.diag_embed(diag_term)

        scale_correction = off_diag + diag_term

        weights = scale_correction + scale_base
        weights = weights.to(rewards.dtype)

        policy_gradient_weights = rewards @ weights
        policy_gradient_weights = policy_gradient_weights.view(rewards.shape)
        policy_gradient_weights = policy_gradient_weights.nan_to_num(0)

        return policy_gradient_weights, weights

    def loo_redaction(self, rewards, scale, k, m):
        """
        Build a per-sample baseline ``b`` that ``recalculate_rewards`` subtracts from the BoN weights.

        Args:
            rewards: 2-D tensor (n, m); ``recalculate_rewards`` passes the rows sorted ascending.
            scale: weight matrix indexed here as 2-D (m, m), as returned by ``bon_scaler``.
            k: not used by this body.
            m: number of generations per prompt (columns of ``rewards``).

        Returns:
            Tensor ``b`` with the same shape and dtype as ``rewards``.

        NOTE(review): ``bon_scaler`` builds ``scale`` with ``torch.tril(..., diagonal=-1)`` plus a
        diagonal, so every ``scale[i, i + 1]`` entry read below lies in the zero upper triangle
        and only diagonal terms contribute. Confirm whether the transposed entries were intended.
        NOTE(review): ``bon_scaler_offpolicy`` returns (n, m, m) weights, which this 2-D indexing
        does not handle.
        """
        b = torch.zeros_like(rewards)
        # Diagonal entries scale[j, j] for j = 1..m-1. Advanced indexing returns a copy, so the
        # in-place update below does not modify `scale`.
        b1_scale = scale[torch.arange(1, m), torch.arange(1, m)]
        b1_scale[:-1] += (torch.arange(2, m) - 2) * scale[
            torch.arange(1, m - 1), torch.arange(1, m - 1) + 1
        ]
        # Baseline of the lowest-ranked sample: weighted sum of the other samples' rewards
        b[:, 0] = rewards[:, 1:] @ b1_scale
        # Each subsequent baseline is obtained incrementally from the previous one
        for i in range(1, m):
            b[:, i] = b[:, i - 1] + (rewards[:, i - 1] - rewards[:, i]) * (
                    scale[i - 1, i - 1] + (i - 1) * scale[i - 1, i]
            )

        return b

    def recalculate_rewards(self, rewards, deltas=None, onpolicy=False):
        squeeze = False

        if rewards.ndim == 1:
            rewards = rewards.unsqueeze(0)
            deltas = deltas.unsqueeze(0)
            squeeze = True
        n, m = rewards.shape  # n prompts, m generations

        # Sort each row independently
        ind = torch.argsort(rewards, dim=1, descending=False)
        sorted_rewards = torch.gather(rewards, 1, ind)

        if onpolicy:
            bon_rewards, scale = self.bon_scaler(sorted_rewards, self.best_k)
        else:
            bon_rewards, scale = self.bon_scaler_offpolicy(
                sorted_rewards, deltas, self.best_k
            )

        if self.var_redaction == "loo":
            b = self.loo_redaction(sorted_rewards, scale, self.best_k, m)
            bon_rewards = bon_rewards - b
        if self.var_redaction == "loo-1":
            b = self.loo_redaction(sorted_rewards, scale, self.best_k - 1, m)
            b = b * self.best_k / m / (self.best_k - 1)
            bon_rewards = bon_rewards - b

        # scatter bon rewards back to original positions
        original_rewards = torch.zeros_like(bon_rewards)

        # print("ind.shape=", ind.shape)
        # print("original_rewards.shape=", original_rewards.shape)
        # print("bon_rewards.shape=", bon_rewards.shape)
        assert len(ind) == len(original_rewards)
        assert len(ind) == len(bon_rewards)
        assert len(ind) == n

        for i in range(n):
            original_rewards[i].scatter_(0, ind[i], bon_rewards[i])

        if squeeze:
            original_rewards = original_rewards.squeeze(0)

        return original_rewards

    def _generate_and_score_completions(
            self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Generate, post-process and score completions for a local batch, and build the loss inputs.

        Pipeline:
          1. Tokenize the chat-templated prompts (left padding, truncated to ``max_prompt_length``).
          2. Generate completions with vLLM (``"server"`` or ``"colocate"`` mode) or with
             ``model.generate``.
          3. Decode the completions, post-process them with ``self.process_completions`` and
             re-tokenize the result. The completion mask covers every token up to and including
             the first EOS.
          4. Compute ``old_per_token_logps`` only when they will be reused
             (``num_iterations > 1`` or ``steps_per_generation > gradient_accumulation_steps``).
          5. Score the completions with every reward function, gather the scores from all
             processes, and combine them with ``reward_weights``. In training with
             ``num_iterations == 1``, replace them by best-of-k weights (``recalculate_rewards``).
          6. Compute advantages with ``self.get_advantages`` and log metrics and completions.

        Args:
            inputs: local batch of examples; each provides at least a ``"prompt"`` entry.

        Returns:
            Dict with ``prompt_ids``, ``prompt_mask``, ``completion_ids``, ``completion_mask``,
            ``advantages``, ``old_per_token_logps`` (``None`` when not needed) and ``rewards``.
            Every tensor covers only this process's part of the batch, so they can all be split
            consistently along the first dimension.
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        prompts = [x["prompt"] for x in inputs]
        prompts_text = [
            maybe_apply_chat_template(example, self.processing_class)["prompt"]
            for example in inputs
        ]
        prompt_inputs = self.processing_class(
            text=prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        # Use the plain transformers.Trainer input preparation: this class overrides
        # _prepare_inputs to run generation, which must not happen here.
        prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
        prompt_ids, prompt_mask = (
            prompt_inputs["input_ids"],
            prompt_inputs["attention_mask"],
        )

        if self.max_prompt_length is not None:
            # Prompts are left-padded, so keep the right-most tokens
            prompt_ids = prompt_ids[:, -self.max_prompt_length:]
            prompt_mask = prompt_mask[:, -self.max_prompt_length:]

        # Generate completions using either vLLM or regular generation
        if self.use_vllm:
            # First, update the vLLM weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            if self.vllm_mode == "server":
                all_prompts_text = gather_object(prompts_text)
                if self.accelerator.is_main_process:
                    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                    # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                    # prompt individually.
                    ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
                    with profiling_context(self, "vLLM.generate"):
                        completion_ids = self.vllm_client.generate(
                            prompts=ordered_set_of_prompts,
                            n=self.num_generations,
                            repetition_penalty=self.repetition_penalty,
                            temperature=self.temperature,
                            top_p=self.top_p,
                            top_k=-1 if self.top_k is None else self.top_k,
                            min_p=0.0 if self.min_p is None else self.min_p,
                            max_tokens=self.max_completion_length,
                            guided_decoding_regex=self.guided_decoding_regex,
                        )
                else:
                    completion_ids = [None] * len(all_prompts_text)
                # Broadcast the completions from the main process to all processes, ensuring each process receives its
                # corresponding slice.
                completion_ids = broadcast_object_list(completion_ids, from_process=0)
                process_slice = slice(
                    self.accelerator.process_index * len(prompts),
                    (self.accelerator.process_index + 1) * len(prompts),
                )
                completion_ids = completion_ids[process_slice]

            # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
            elif self.vllm_mode == "colocate":
                if self.guided_decoding_regex:
                    guided_decoding = GuidedDecodingParams(
                        backend="outlines", regex=self.guided_decoding_regex
                    )
                else:
                    guided_decoding = None
                sampling_params = SamplingParams(
                    n=1,  # vLLM on each GPU generates only 1 in colocate mode
                    repetition_penalty=self.repetition_penalty,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=-1 if self.top_k is None else self.top_k,
                    min_p=0.0 if self.min_p is None else self.min_p,
                    max_tokens=self.max_completion_length,
                    guided_decoding=guided_decoding,
                )

                if self.vllm_tensor_parallel_size > 1:
                    # Gather prompts from all ranks in the TP group and flatten.
                    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
                    orig_size = len(prompts_text)
                    gathered_prompts = [
                        None for _ in range(self.vllm_tensor_parallel_size)
                    ]
                    torch.distributed.all_gather_object(
                        gathered_prompts, prompts_text, group=self.tp_group
                    )
                    all_prompts_text = [
                        p for sublist in gathered_prompts for p in sublist
                    ]
                else:
                    all_prompts_text = prompts_text

                with profiling_context(self, "vLLM.generate"):
                    all_outputs = self.llm.generate(
                        all_prompts_text,
                        sampling_params=sampling_params,
                        use_tqdm=False,
                    )

                completion_ids = [
                    output.token_ids
                    for outputs in all_outputs
                    for output in outputs.outputs
                ]

                if self.vllm_tensor_parallel_size > 1:
                    # Slice completions for this rank within its TP group.
                    # Each rank generates all outputs — we keep only our share.
                    local_rank_in_group = torch.distributed.get_rank(
                        group=self.tp_group
                    )
                    tp_slice = slice(
                        local_rank_in_group * orig_size,
                        (local_rank_in_group + 1) * orig_size,
                    )
                    completion_ids = completion_ids[tp_slice]

            # Pad the completions into a single tensor
            completion_ids = [
                torch.tensor(ids, device=device) for ids in completion_ids
            ]
            completion_ids = pad(
                completion_ids, padding_value=self.processing_class.pad_token_id
            )
        else:
            # Regular generation path
            with unwrap_model_for_generation(
                    self.model_wrapped,
                    self.accelerator,
                    gather_deepspeed3_params=self.args.ds3_gather_for_generation,
            ) as unwrapped_model:
                with (
                    FSDP.summon_full_params(self.model_wrapped, recurse=False)
                    if self.is_fsdp_enabled
                    else nullcontext()
                ):
                    prompt_completion_ids = unwrapped_model.generate(
                        prompt_ids,
                        attention_mask=prompt_mask,
                        generation_config=self.generation_config,
                    )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

        # Decode completions so they can be post-processed as text
        completions_text = self.processing_class.batch_decode(
            completion_ids, skip_special_tokens=True
        )
        processed_completions = self.process_completions(completions_text, inputs)
        completions_text = processed_completions["completions_text"]

        # Any other value returned by process_completions is recorded as a metric of the current
        # mode, using the same per-mode layout as every other metric logged below.
        for metric_name, metric_value in processed_completions.items():
            if metric_name != "completions_text":
                self._metrics[mode][metric_name].append(metric_value)

        # Encode the post-processed completions back to ids
        modified_completion_inputs = self.processing_class(
            completions_text, return_tensors="pt", padding=True, add_special_tokens=True
        ).to(device)

        completion_ids = modified_completion_inputs["input_ids"]

        if len(completion_ids[0]) == 0:
            # All completions tokenized to nothing: fall back to a single EOS token per row.
            # TODO look into this
            completion_ids = torch.full(
                (completion_ids.size(0), 1),
                self.processing_class.eos_token_id,
                dtype=torch.long,
                device=device,
            )

        # Remove BOS token if present (checked on the first row, stripped from all rows)
        if len(completion_ids[0]) > 0 and completion_ids[0][0] == self.processing_class.bos_token_id:
            completion_ids = completion_ids[:, 1:]
            if "attention_mask" in modified_completion_inputs:
                modified_completion_inputs["attention_mask"] = (
                    modified_completion_inputs["attention_mask"][:, 1:]
                )

        # Concatenate prompt_ids with modified completion_ids
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full(
            (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
        )
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
            is_eos.size(0), -1
        )
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Unmasked token ids per completion; passed to the reward functions so that token-based
        # rewards do not need to re-tokenize.
        completion_ids_list = [
            [token_id.item() for token_id, keep in zip(row, mask_row) if keep]
            for row, mask_row in zip(completion_ids, completion_mask)
        ]

        # Completion length per sequence, used for logging
        completion_lengths = completion_mask.sum(1)

        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
        if self.mask_truncated_completions:
            truncated_completions = ~is_eos.any(dim=1)
            completion_mask = (
                    completion_mask * (~truncated_completions).unsqueeze(1).int()
            )

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        # Only the logits of the completion tokens are needed
        logits_to_keep = completion_ids.size(1)
        batch_size = (
            self.args.per_device_train_batch_size
            if mode == "train"
            else self.args.per_device_eval_batch_size
        )

        with torch.no_grad():
            # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
            # old_per_token_logps == per_token_logps, so the computation is skipped here and
            # per_token_logps.detach() can be used instead.
            if (
                    self.num_iterations > 1
                    or self.args.steps_per_generation
                    > self.args.gradient_accumulation_steps
            ):
                old_per_token_logps = self._get_per_token_logps(
                    self.model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size,
                )
            else:
                old_per_token_logps = None

        # Decode the (post-processed) completions
        completions_text = self.processing_class.batch_decode(
            completion_ids, skip_special_tokens=True
        )
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                bootstrap = (
                    prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                )
                completions.append(
                    [{"role": "assistant", "content": bootstrap + completion}]
                )
        else:
            completions = completions_text

        rewards_per_func = torch.zeros(
            len(prompts), len(self.reward_funcs), device=device
        )

        # Forward all input columns (but "prompt", "completion", and "completion_ids") to the reward functions
        keys = [
            key
            for key in inputs[0]
            if key not in ["prompt", "completion", "completion_ids"]
        ]
        reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

        for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
                zip(
                    self.reward_funcs,
                    self.reward_processing_classes,
                    self.reward_func_names,
                )
        ):
            with profiling_context(self, reward_func_name):
                if isinstance(
                        reward_func, nn.Module
                ):  # Module (no PretrainedModel) for compat with compiled models
                    if is_conversational(inputs[0]):
                        messages = [
                            {"messages": p + c} for p, c in zip(prompts, completions)
                        ]
                        texts = [
                            apply_chat_template(x, reward_processing_class)["text"]
                            for x in messages
                        ]
                    else:
                        texts = [p + c for p, c in zip(prompts, completions)]
                    reward_inputs = reward_processing_class(
                        text=texts,
                        return_tensors="pt",
                        padding=True,
                        padding_side="right",
                        add_special_tokens=False,
                    )
                    # Same as for the prompts: the generation override of _prepare_inputs
                    # must not be used for reward-model inputs.
                    reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
                    with torch.inference_mode():
                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
                            :, 0
                        ]  # Shape (B*G,)
                else:
                    output_reward_func = reward_func(
                        prompts=prompts,
                        completions=completions,
                        completion_ids=completion_ids_list,
                        **reward_kwargs,
                    )
                    # Convert None values to NaN
                    output_reward_func = [
                        reward if reward is not None else torch.nan
                        for reward in output_reward_func
                    ]

                    rewards_per_func[:, i] = torch.tensor(
                        output_reward_func, dtype=torch.float32, device=device
                    )

        # If all reward functions return None for a given row, issue a detailed warning
        if torch.isnan(rewards_per_func).all(dim=1).any():
            nan_row_idx = (
                torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
            )
            row_reward_kwargs = {
                key: value[nan_row_idx] for key, value in reward_kwargs.items()
            }
            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
            row_reward_kwargs["completion"] = completions[nan_row_idx]
            warnings.warn(
                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
                "Please ensure that at least one reward function returns a valid reward."
            )

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum (global, all processes)
        rewards = (
                rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
        ).nansum(dim=1)

        # Replace the rewards by best-of-k weights
        if mode == "train" and self.num_iterations == 1:
            rewards = self.recalculate_rewards(rewards.view(-1, self.num_generations), onpolicy=True)
            rewards = rewards.view(-1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
        is_std_zero = torch.isclose(
            std_grouped_rewards, torch.zeros_like(std_grouped_rewards)
        )

        # Broadcast the group statistics back to one value per completion
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
            self.num_generations, dim=0
        )
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(
            self.num_generations, dim=0
        )

        advantages = self.get_advantages(
            rewards, mean_grouped_rewards, std_grouped_rewards, self.num_generations
        )

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
        advantages = advantages[process_slice]
        # `rewards` stays global for logging; the returned copy must be local like every other tensor
        local_rewards = rewards[process_slice]

        # Log the metrics
        if mode == "train":
            self.state.num_input_tokens_seen += (
                self.accelerator.gather(attention_mask.sum()).sum().item()
            )
        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

        # Log completion lengths, mean, min, max
        agg_completion_lengths = self.accelerator.gather(completion_lengths)
        self._metrics[mode]["completions/mean_length"].append(
            agg_completion_lengths.float().mean().item()
        )
        self._metrics[mode]["completions/min_length"].append(
            agg_completion_lengths.float().min().item()
        )
        self._metrics[mode]["completions/max_length"].append(
            agg_completion_lengths.float().max().item()
        )

        # Identify sequences that terminated with EOS and log their lengths
        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(
            agg_completion_lengths
        )
        self._metrics[mode]["completions/clipped_ratio"].append(
            clipped_completions_ratio
        )
        if len(term_completion_lengths) == 0:
            # Edge case where no terminated sequences are found
            term_completion_lengths = torch.zeros(1, device=device)
        self._metrics[mode]["completions/mean_terminated_length"].append(
            term_completion_lengths.float().mean().item()
        )
        self._metrics[mode]["completions/min_terminated_length"].append(
            term_completion_lengths.float().min().item()
        )
        self._metrics[mode]["completions/max_terminated_length"].append(
            term_completion_lengths.float().max().item()
        )

        # Gather the texts once on every process. gather_object is a collective operation, so it
        # must be executed by all processes; calling it from the main process alone blocks.
        gathered_prompts_text = gather_object(prompts_text)
        gathered_completions_text = gather_object(completions_text)

        # Depends only on configuration and the global step, so all processes agree on it and
        # the collective inside stays in lock-step.
        log_completion_table = (
                self.log_completions
                and self.state.global_step % self.args.logging_steps == 0
                and "wandb" in self._original_report_to
        )
        if log_completion_table:
            # Gathered so that this column has the same (global) length as the others
            gathered_ground_truth = gather_object(
                [example.get("ground_truth") for example in inputs]
            )
            if self.accelerator.is_main_process and wandb.run is not None:
                import pandas as pd

                table = {
                    "step": [str(self.state.global_step)] * len(rewards),
                    "prompt": gathered_prompts_text,
                    "completion": gathered_completions_text,
                    "reward": rewards.tolist(),
                    "ground_truth": gathered_ground_truth,
                }

                df = pd.DataFrame(table)
                # Create a table with a unique name for each step
                table_name = f"completions_step_{self.state.global_step}"

                # STORE instead of logging immediately
                self._pending_tables[table_name] = wandb.Table(dataframe=df)

        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
        for i, reward_func_name in enumerate(self.reward_func_names):
            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
            std_rewards = nanstd(rewards_per_func[:, i]).item()
            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
        self._metrics[mode]["frac_reward_zero_std"].append(
            is_std_zero.float().mean().item()
        )

        # Log prompt and completion texts
        self._textual_logs["prompt"].extend(gathered_prompts_text)
        self._textual_logs["completion"].extend(gathered_completions_text)
        for i, name in enumerate(self.reward_func_names):
            self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        self._textual_logs["advantages"].extend(all_process_advantages.tolist())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "advantages": advantages,
            "old_per_token_logps": old_per_token_logps,
            "rewards": local_rewards,
        }

    def _compute_loss(self, model, inputs):
        """Return the clipped policy-gradient loss for one micro-batch and log metrics.

        Args:
            model: The policy model whose per-token log-probabilities are optimized.
            inputs: Mapping produced by the generation/scoring step. Keys read here:
                ``prompt_ids``, ``prompt_mask``, ``completion_ids``,
                ``completion_mask``, ``advantages``, ``old_per_token_logps``
                (may be ``None``) and, only when ``self.num_iterations > 1``,
                ``rewards``.

        Returns:
            A scalar loss tensor, reduced according to ``self.loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not ``"grpo"``, ``"bnpo"`` or
                ``"dr_grpo"``.

        Side effects: appends a KL statistic (only when ``self.beta != 0``) and
        clip-ratio statistics, each gathered across processes, to
        ``self._metrics[mode]``.
        """
        prompt_ids = inputs["prompt_ids"]
        prompt_mask = inputs["prompt_mask"]
        completion_ids = inputs["completion_ids"]
        completion_mask = inputs["completion_mask"]

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        # Logits are only required for the completion positions.
        n_completion_tokens = completion_ids.size(1)

        logps = self._get_per_token_logps(
            model, input_ids, attention_mask, n_completion_tokens
        )

        kl_term = None
        if self.beta != 0.0:
            kl_term = self._kl_against_reference(
                logps, input_ids, attention_mask, n_completion_tokens
            )

        entropy_term = None
        if self.entropy_coef != 0.0:
            # -p * log p of the scored token only (not a full-vocabulary entropy).
            # NOTE(review): this term is ADDED to a loss that is minimized, so a
            # positive entropy_coef pushes this quantity down; confirm the sign
            # convention matches the "increase diversity" intent.
            entropy_term = -torch.exp(logps) * logps

        advantages = inputs["advantages"]

        # No stored old log-probs means the data is exactly on-policy, so the
        # detached current log-probs stand in (importance ratio == 1).
        old_logps = inputs["old_per_token_logps"]
        if old_logps is None:
            old_logps = logps.detach()
        ratio = torch.exp(logps - old_logps)
        clipped_ratio = torch.clamp(
            ratio, 1 - self.epsilon_low, 1 + self.epsilon_high
        )

        if self.num_iterations > 1:
            # Reused generations: rebuild advantages from sequence-level ratios.
            advantages = self._recompute_offpolicy_advantages(
                inputs["rewards"], logps, old_logps, completion_mask
            )

        # Two-sided clipping: additionally cap the unclipped ratio from above.
        if self.args.delta is not None:
            ratio = torch.clamp(ratio, max=self.args.delta)

        token_advantages = advantages.unsqueeze(1)
        per_token_loss = -torch.min(
            ratio * token_advantages, clipped_ratio * token_advantages
        )
        if kl_term is not None:
            per_token_loss = per_token_loss + self.beta * kl_term
        if entropy_term is not None:
            per_token_loss = per_token_loss + self.entropy_coef * entropy_term

        loss = self._reduce_token_loss(per_token_loss, completion_mask)

        mode = "train" if self.model.training else "eval"

        if kl_term is not None:
            masked_kl_mean = (kl_term * completion_mask).sum() / completion_mask.sum()
            self._metrics[mode]["kl"].append(
                self.accelerator.gather(masked_kl_mean).nanmean().item()
            )

        self._record_clip_metrics(mode, ratio, advantages, completion_mask)
        return loss

    def _kl_against_reference(
            self, per_token_logps, input_ids, attention_mask, logits_to_keep
    ):
        """Per-token k3 KL estimate, exp(r) - r - 1 with r = ref_logp - logp.

        The reference log-probs come from ``self.ref_model`` when it is set.
        Otherwise they come from ``self.model`` with its adapter disabled
        (presumably a PEFT model; verify). They are computed without gradients.
        """
        with torch.no_grad():
            if self.ref_model is None:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    reference_logps = self._get_per_token_logps(
                        self.model, input_ids, attention_mask, logits_to_keep
                    )
            else:
                reference_logps = self._get_per_token_logps(
                    self.ref_model, input_ids, attention_mask, logits_to_keep
                )
        log_ratio = reference_logps - per_token_logps
        return torch.exp(log_ratio) - log_ratio - 1

    def _recompute_offpolicy_advantages(
            self, rewards, per_token_logps, old_per_token_logps, completion_mask
    ):
        """Rebuild group-normalized advantages from reweighted rewards.

        ``rewards`` must hold one entry per completion row in this micro-batch.
        The ``view(grouped.shape)`` below requires matching element counts.
        Consecutive blocks of ``self.num_generations`` entries are treated as one
        group (presumably completions of the same prompt; confirm with caller).
        """
        group_size = self.num_generations
        grouped = rewards.view(-1, group_size)

        # Sequence-level log-probs: sum of masked per-token log-probs per row.
        seq_logps = (per_token_logps * completion_mask).sum(dim=1, keepdim=True)
        seq_old_logps = (old_per_token_logps * completion_mask).sum(
            dim=1, keepdim=True
        )
        # Relative change of the sequence likelihood; detached, so it only
        # reweights rewards and does not receive gradient.
        relative_change = torch.exp(seq_logps - seq_old_logps) - 1
        relative_change = relative_change.view(grouped.shape).detach()

        reweighted = self.recalculate_rewards(
            grouped, relative_change, onpolicy=False
        ).view(-1)
        per_group = reweighted.view(-1, group_size)
        group_mean = per_group.mean(dim=1).repeat_interleave(group_size, dim=0)
        group_std = per_group.std(dim=1).repeat_interleave(group_size, dim=0)
        return (reweighted - group_mean) / (group_std + 1e-4)

    def _reduce_token_loss(self, per_token_loss, completion_mask):
        """Collapse the masked per-token loss to a scalar per ``self.loss_type``.

        - ``grpo``: mean over each sequence's valid tokens, then mean over sequences.
        - ``bnpo``: sum over all valid tokens / number of valid tokens in the batch.
        - ``dr_grpo``: sum over valid tokens / (rows * ``self.max_completion_length``).
        """
        reducers = {
            "grpo": lambda masked: (
                    masked.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
            ).mean(),
            "bnpo": lambda masked: masked.sum() / completion_mask.sum().clamp(min=1.0),
            "dr_grpo": lambda masked: masked.sum() / (
                    per_token_loss.size(0) * self.max_completion_length
            ),
        }
        reducer = reducers.get(self.loss_type)
        if reducer is None:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
        return reducer(per_token_loss * completion_mask)

    def _record_clip_metrics(self, mode, ratio, advantages, completion_mask):
        """Log, across processes, the fraction of valid tokens in each clip region.

        "low" marks ratio < 1 - epsilon_low with a negative advantage, and "high"
        marks ratio > 1 + epsilon_high with a positive advantage. "region" is
        their union. The gathers run in a fixed order (low, high, region) on
        every process. Fractions are NaN when the mask is empty, which is why
        nan-aware reductions are used.
        """
        token_advantages = advantages.unsqueeze(1)
        below = (ratio < 1 - self.epsilon_low) & (token_advantages < 0)
        above = (ratio > 1 + self.epsilon_high) & (token_advantages > 0)
        either = below | above

        n_valid = completion_mask.sum()
        below_frac = (below * completion_mask).sum() / n_valid
        above_frac = (above * completion_mask).sum() / n_valid
        either_frac = (either * completion_mask).sum() / n_valid

        metrics = self._metrics[mode]

        all_below = self.accelerator.gather(below_frac)
        metrics["clip_ratio/low_mean"].append(all_below.nanmean().item())
        metrics["clip_ratio/low_min"].append(nanmin(all_below).item())

        all_above = self.accelerator.gather(above_frac)
        metrics["clip_ratio/high_mean"].append(all_above.nanmean().item())
        metrics["clip_ratio/high_max"].append(nanmax(all_above).item())

        all_either = self.accelerator.gather(either_frac)
        metrics["clip_ratio/region_mean"].append(all_either.nanmean().item())
