import copy
import json
import math
import re
import ast
from collections import defaultdict
from typing import Optional
import torch
import transformers
from transformers import Trainer

from trl.models import unwrap_model_for_generation
from packaging import version

from prefix_grouper import PrefixGrouper
from qwenvl.utils import rank0_print
from qwenvl.utils.mem_cache import empty_cache
from qwenvl.utils.device import forward_offload


def split_list(lst, n):
    """
    Split a list into at most ``n`` contiguous, non-empty chunks.

    Every chunk holds ``ceil(len(lst) / n)`` items, except the last one which
    may be shorter. When that chunk size does not divide the list evenly,
    fewer than ``n`` chunks are returned instead of padding the result with
    empty lists (e.g. 5 items with ``n=4`` yields 3 chunks).

    Args:
        lst: The sequence to split (anything supporting ``len`` and slicing).
        n: The maximum number of chunks; must be a positive integer.

    Returns:
        A list of non-empty slices of ``lst`` which concatenate back to ``lst``.
        An empty input yields an empty list.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    total_length = len(lst)
    chunk_size = math.ceil(total_length / n)
    if chunk_size == 0:
        # Empty input: there is nothing to hand out.
        return []
    # Stepping by ``chunk_size`` visits exactly the non-empty chunk starts.
    return [lst[start : start + chunk_size] for start in range(0, total_length, chunk_size)]


class GRPOTrainer(Trainer):
    def __init__(self, *args, vllm_model=None, processor=None, judge_model, judge_tokenizer, suffix_end_str: str = "<|im_end|>\n", **kwargs):
        """
        Set up the trainer together with its judge model, vLLM sampler and optional reference model.

        Args:
            *args: Positional arguments forwarded to ``Trainer.__init__``.
            vllm_model: Optional vLLM engine used for sampling. When given, its buffers are
                snapshotted and the engine is put to sleep (level 2) until generation.
            processor: Processor used to tokenize prompts and completions.
            judge_model: Model used to score responses (required, keyword-only). It is moved
                to CPU and switched to eval mode here.
            judge_tokenizer: Tokenizer handed to ``judge_model`` on every call.
            suffix_end_str: String appended to every response / reference caption before
                it is tokenized.
            **kwargs: Keyword arguments forwarded to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # NOTE: use ``cpu_offload`` to save memory
        self.judge_model = judge_model.to("cpu").eval()
        self.judge_tokenizer = judge_tokenizer
        self.vllm_model = vllm_model
        if vllm_model is not None:
            # NOTE: We should save the buffers in order to restore them during wake-up, because
            # load_weights won't load the buffers, and the buffers will be zeroed during sleep(level=2)
            self.vllm_model_buffers = {name: buffer.clone() for name, buffer in self.get_unwrapped_vllm_model(vllm_model).named_buffers()}
            vllm_model.sleep(2)

        self.processor = processor

        # Per-step metric values, averaged and flushed by ``log``.
        self._metrics = defaultdict(list)

        self.kl_beta = self.args.kl_beta
        self.max_new_tokens = self.args.max_new_tokens
        self.sample_temperature = self.args.sample_temperature

        ref_model = None
        if self.args.use_ref_kl:
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                ref_model = copy.deepcopy(unwrapped_model)
                ref_model.eval()
                # NOTE: use ``cpu_offload`` to save memory
                ref_model = forward_offload(ref_model, self.model.device)
                ref_model.to("cpu")
        self.ref_model = ref_model
        # The reference model only exists when ``use_ref_kl`` is enabled; reading ``.device``
        # unconditionally would raise AttributeError on ``None``.
        if self.ref_model is not None:
            print("Ref model device: ", self.ref_model.device)
        # NOTE: Used for formatting the responses
        self.suffix_end_str = suffix_end_str
        # Running average of the gathered mean reward; stays ``None`` until the first update.
        self.reward_ema = None

    @property
    def force_ref_threshold(self):
        return self.args.force_ref_threshold

    @property
    def relative_ref_threshold(self):
        return self.args.relative_ref_threshold

    @property
    def lp_threshold_len(self):
        return self.args.lp_threshold_len
    
    @property
    def lp_alpha(self):
        return self.args.lp_alpha
    
    @property
    def lp_beta(self):
        return self.args.lp_beta

    @property
    def ema_gamma(self):
        return self.args.ema_gamma

    @staticmethod
    def get_unwrapped_vllm_model(vllm_model):
        if vllm_model is not None:
            return vllm_model.llm_engine.model_executor.driver_worker.model_runner.model
        return None

    def pad_sequence(self, input_ids, batch_first, padding_value, padding_side):
        if padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @torch.inference_mode()
    def generate_vllm(self, model, inputs):
        """
        Sample ``args.grpo_group_size`` completions for the prompt in ``inputs`` using vLLM.

        The engine is woken up, refreshed with the current policy parameters and with the
        buffers saved at construction time, used for sampling and finally put back to sleep.

        NOTE: ``inputs`` is modified in place. ``multi_modal_data``, ``mm_processor_kwargs`` and
        ``prompt`` are replaced by their first element, and the first video (or, failing that,
        the first image) is moved to CPU and cast to float.

        Args:
            model: The (possibly wrapped) policy model whose parameters are copied into vLLM.
            inputs: Dict holding ``prompt``, ``multi_modal_data`` and ``mm_processor_kwargs``.

        Returns:
            A ``(responses, responses_ids)`` pair: the generated texts and their token ids.

        Raises:
            ValueError: If ``multi_modal_data`` contains neither ``"video"`` nor ``"image"``.
        """
        from vllm import SamplingParams

        self.vllm_model.wake_up()

        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            engine_model = self.get_unwrapped_vllm_model(self.vllm_model)
            # Push a CPU copy of every policy parameter into the engine.
            policy_weights = tuple(
                (name, param.detach().to("cpu", copy=True).contiguous())
                for name, param in unwrapped_model.named_parameters()
            )
            engine_model.load_weights(policy_weights)
            # Buffers are not covered by ``load_weights``; restore them from the snapshot.
            for name, buffer in engine_model.named_buffers():
                buffer.copy_(self.vllm_model_buffers[name])

        engine_model.eval()

        # Each of these arrives wrapped in a one-element container; keep only the first entry.
        for key in ("multi_modal_data", "mm_processor_kwargs", "prompt"):
            inputs[key] = inputs[key][0]

        media = inputs["multi_modal_data"]
        for modality in ("video", "image"):
            if modality in media:
                media[modality][0] = media[modality][0].cpu().float()
                break
        else:
            raise ValueError("multi_modal_data should contain either 'video' or 'image'")

        sampling_params = SamplingParams(
            self.args.grpo_group_size,
            temperature=self.args.sample_temperature,
            max_tokens=self.args.max_new_tokens,
        )
        request_outputs = self.vllm_model.generate([inputs], sampling_params=sampling_params, use_tqdm=True)

        responses = []
        responses_ids = []
        for request_output in request_outputs:
            for completion in request_output.outputs:
                responses.append(completion.text)
                responses_ids.append(completion.token_ids)

        self.vllm_model.sleep(2)
        return responses, responses_ids

    def calculate_rewards(self, responses, questions, ref_captions):
        """
        Score the sampled responses with the judge model.

        The judge model is moved to the policy device for the duration of the call and always
        moved back to CPU afterwards. It is queried with every response paired with the
        questions of its prompt; if a call fails the batch is retried in twice as many chunks.

        Args:
            responses: Generated captions; consecutive runs of ``args.grpo_group_size`` entries
                belong to the same prompt.
            questions: One dict per prompt mapping each of ``"cinematography"``, ``"visual"``,
                ``"frame"`` and ``"temporal"`` to a list of questions.
            ref_captions: Reference captions per prompt; only ``ref_captions[0][0]`` is read,
                and only when ``args.use_precision_rew`` is set.

        Returns:
            Dict of tensors: one per-response reward tensor for each question type,
            ``"precision"`` (a ``[0.0]`` placeholder when precision rewards are disabled) and
            the combined ``"rewards"``.

        Raises:
            ValueError: If the judge still fails with one item per chunk.
        """
        question_types = ("cinematography", "visual", "frame", "temporal")
        self.judge_model.to(self.model.device)
        try:
            data = []
            q_lengths = []

            for index, response in enumerate(responses):
                # Consecutive ``grpo_group_size`` responses share the question set of one prompt.
                q_data = questions[index // self.args.grpo_group_size]

                all_questions = []
                lengths = []
                for q_type in question_types:
                    q_list = q_data[q_type]
                    all_questions.extend(q_list)
                    lengths.append(len(q_list))

                data.append({"caption": response, "questions": all_questions})
                q_lengths.append(lengths)

            with torch.inference_mode():
                num_split_chunks = 1
                # Only the text of the last failure is kept: holding the exception object would
                # keep its traceback (and whatever tensors it references) alive across retries.
                last_error = None
                while True:
                    if num_split_chunks > len(data):
                        detail = f" (last judge error: {last_error})" if last_error is not None else ""
                        raise ValueError("Failed even if split size is as small as possible" + detail)
                    results = []
                    success = True
                    for sub_data in split_list(data, num_split_chunks):
                        if not sub_data:
                            # Uneven splits may produce empty chunks; there is nothing to judge.
                            continue
                        try:
                            results.extend(self.judge_model(sub_data, self.judge_tokenizer))
                        except Exception as err:
                            last_error = repr(err)
                            success = False
                            break
                    if success:
                        break
                    # Smaller batch size
                    num_split_chunks = 2 * num_split_chunks

            final_rewards = {q_type: [] for q_type in question_types}
            for idx, result in enumerate(results):
                lengths = q_lengths[idx]
                start = 0
                for q_type, length in zip(question_types, lengths):
                    end = start + length
                    sub_result = result[start:end]
                    start = end

                    # Judge label 0 earns full credit, label 2 half credit, anything else none.
                    # NOTE(review): raises ZeroDivisionError when a question type has no
                    # questions - confirm that the dataset guarantees at least one per type.
                    count_0 = sub_result.count(0)
                    count_2 = sub_result.count(2)
                    reward_val = (count_0 + count_2 * 0.5) / len(sub_result)
                    final_rewards[q_type].append(reward_val)

            for q_type in question_types:
                final_rewards[q_type] = torch.tensor(final_rewards[q_type])

            # Precision score
            if self.args.use_precision_rew:
                # FIXME: For quick implementation, only batch size=1 is supported on each gpu.
                final_rewards["precision"] = (torch.tensor(
                    self.judge_model([ref_captions[0][0], *responses], self.judge_tokenizer, type="precision")
                ).flatten() + 1) / 10
            else:
                final_rewards["precision"] = torch.tensor([0.0])
        finally:
            self.judge_model.to("cpu")
            empty_cache()

        if self.args.use_weighted_rew:
            weights = {
                "cinematography": 0.15,
                "visual": 0.15,
                "frame": 0.35,
                "temporal": 0.35,
            }
            rank0_print("weighted: ", weights)
            final_rewards["rewards"] = (
                final_rewards["cinematography"] * weights["cinematography"]
                + final_rewards["visual"] * weights["visual"]
                + final_rewards["frame"] * weights["frame"]
                + final_rewards["temporal"] * weights["temporal"]
            )
        else:
            # Average the four question-type rewards only. ``"precision"`` must stay out of the
            # stack: it is a shape-[1] placeholder when disabled (which cannot be stacked with
            # per-response tensors) and is blended in separately below when enabled.
            final_rewards["rewards"] = torch.stack([final_rewards[q_type] for q_type in question_types]).mean(dim=0)

        if self.args.use_precision_rew:
            # Add precision score
            final_rewards["rewards"] = (1 - self.args.precision_weight) * final_rewards["rewards"] + self.args.precision_weight * final_rewards["precision"]
        return final_rewards

    def calculate_advantages(self, ref_captions, resp_ids, rewards, responses):
        def calculate_gentle_gt_reward(on_policy_rewards: torch.Tensor) -> torch.Tensor:
            mu_on = torch.mean(on_policy_rewards)
            advantages = on_policy_rewards - mu_on
            positive_advantages = advantages[advantages > 0]
            if positive_advantages.numel() == 0:
                gt_reward = mu_on
            else:
                min_pos_adv = torch.min(positive_advantages)
                gt_reward = mu_on + min_pos_adv
            return gt_reward

        resp_ids_len = torch.tensor([len(resp_id) for resp_id in resp_ids])
        assert len(resp_ids_len) == len(rewards)
        penalty = 1 - (1 - self.lp_alpha) * (resp_ids_len / self.lp_threshold_len) ** self.lp_beta
        rewards = rewards * torch.where(penalty >= self.lp_alpha, penalty, self.lp_alpha)
        # NOTE: Whether to include ref_list in inputs is controlled by 2 factors:
        # 1. ``use_gt_ref`` arg.
        # 2. ``force_ref_threshold`` and ``relative_ref_threshold`` (which means the number of ref captions may differ in different samples)
        # NOTE: We use list here rather than using torch.tensor to enable parallel computing, because the tensor sizes may differ in each group
        final_ref_list = []
        final_group_sizes = []
        final_mean_grouped_rewards = []
        final_std_grouped_rewards = []

        grouped_rewards = rewards.split(self.args.grpo_group_size)
        split_size = self.args.grpo_group_size
        responses = [responses[i * split_size : (i + 1) * split_size] for i in range(math.ceil(len(responses) / split_size))]
        assert len(grouped_rewards) == len(ref_captions) == len(responses)
        rewards = []
        # NOTE: We add reward mask here, because when applying ref_threshold, different samples may introduce or not introduce ref reward, causing
        # stuck during DDP training, so we should introduce ref reward in all samples, and then mask them.
        reward_mask = []
        ref_mask = []  # ref mask for importance sampling in gt samples
        for ref_list, grouped_rew, resps in zip(ref_captions, grouped_rewards, responses):
            max_grouped_rew = grouped_rew.max()
            ref_list_ = []
            final_group_size = self.args.grpo_group_size
            grouped_reward_mask = torch.ones(self.args.grpo_group_size, dtype=torch.bool, device=grouped_rew.device)
            grouped_ref_mask = torch.ones(self.args.grpo_group_size, dtype=torch.bool, device=grouped_rew.device)
            if self.args.use_ema_gt_ref:
                ref_list_ = ref_list
                final_group_size = final_group_size + len(ref_list)
                mean_grouped_rew = grouped_rew.mean()
                grouped_ref_mask = torch.cat(
                    [
                        torch.zeros(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                        grouped_ref_mask,
                    ]
                )
                if (
                    self.reward_ema is None
                    or (not self.args.use_max_ema and mean_grouped_rew >= self.reward_ema)
                    or (self.args.use_max_ema and max_grouped_rew >= self.reward_ema)
                ):
                    # NOTE: This is a placeholder to avoid stuck during DDP training
                    grouped_rew = torch.cat(
                        [
                            torch.tensor([0.0], dtype=grouped_rew.dtype, device=grouped_rew.device).repeat(len(ref_list)),
                            grouped_rew,
                        ]
                    )
                    grouped_reward_mask = torch.cat(
                        [
                            # NOTE: mask the rewards
                            torch.zeros(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                            grouped_reward_mask,
                        ]
                    )
                else:
                    # NOTE: Add relative ref reward
                    grouped_rew = torch.cat(
                        [
                            (calculate_gentle_gt_reward(grouped_rew) if self.args.use_gentle_gt_ref else max_grouped_rew).repeat(len(ref_list)),
                            grouped_rew,
                        ]
                    )
                    grouped_reward_mask = torch.cat(
                        [
                            torch.ones(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                            grouped_reward_mask,
                        ]
                    )
            elif self.args.use_gentle_gt_ref:
                ref_list_ = ref_list
                final_group_size = final_group_size + len(ref_list)
                mean_grouped_rew = grouped_rew.mean()
                grouped_ref_mask = torch.cat(
                    [
                        torch.zeros(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                        grouped_ref_mask,
                    ]
                )
                grouped_rew = torch.cat(
                    [
                        calculate_gentle_gt_reward(grouped_rew).repeat(len(ref_list)),
                        grouped_rew,
                    ]
                )
                grouped_reward_mask = torch.cat(
                    [
                        torch.ones(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                        grouped_reward_mask,
                    ]
                )
            elif self.args.use_gt_ref:
                ref_list_ = ref_list
                final_group_size = final_group_size + len(ref_list)
                grouped_ref_mask = torch.cat(
                    [
                        torch.zeros(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                        grouped_ref_mask,
                    ]
                )
                if max_grouped_rew <= self.force_ref_threshold:
                    # NOTE: Add force ref reward
                    grouped_rew = torch.cat(
                        [
                            torch.tensor([1.0], dtype=grouped_rew.dtype, device=grouped_rew.device).repeat(len(ref_list)),
                            grouped_rew,
                        ]
                    )
                    grouped_reward_mask = torch.cat(
                        [
                            torch.ones(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                            grouped_reward_mask,
                        ]
                    )
                elif (self.force_ref_threshold <= max_grouped_rew) and (max_grouped_rew <= self.relative_ref_threshold):
                    # NOTE: Add relative ref reward
                    grouped_rew = torch.cat(
                        [
                            max_grouped_rew.repeat(len(ref_list)),
                            grouped_rew,
                        ]
                    )
                    grouped_reward_mask = torch.cat(
                        [
                            torch.ones(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                            grouped_reward_mask,
                        ]
                    )
                else:
                    # NOTE: This is a placeholder to avoid stuck during DDP training
                    grouped_rew = torch.cat(
                        [
                            torch.tensor([0.0], dtype=grouped_rew.dtype, device=grouped_rew.device).repeat(len(ref_list)),
                            grouped_rew,
                        ]
                    )
                    grouped_reward_mask = torch.cat(
                        [
                            # NOTE: mask the rewards
                            torch.zeros(len(ref_list), dtype=torch.bool, device=grouped_rew.device),
                            grouped_reward_mask,
                        ]
                    )

            final_ref_list.append(ref_list_)
            final_group_sizes.append(final_group_size)
            # Compute mean and std
            final_mean_grouped_rewards.append(grouped_rew[grouped_reward_mask].mean())
            final_std_grouped_rewards.append(grouped_rew[grouped_reward_mask].std())
            rewards.append(grouped_rew)
            reward_mask.append(grouped_reward_mask)
            ref_mask.append(grouped_ref_mask)
        rewards = torch.cat(rewards)
        reward_mask = torch.cat(reward_mask)
        ref_mask = torch.cat(ref_mask)
        final_group_sizes = torch.tensor(final_group_sizes, dtype=torch.long, device=rewards.device)
        final_mean_grouped_rewards = torch.stack(final_mean_grouped_rewards).repeat_interleave(final_group_sizes, dim=0)
        final_std_grouped_rewards = torch.stack(final_std_grouped_rewards).repeat_interleave(final_group_sizes, dim=0)
        advantages = (rewards - final_mean_grouped_rewards) / (final_std_grouped_rewards + 1e-4)
        advantages = advantages * reward_mask.to(advantages.dtype)  # NOTE: mask the placeholder advantages
        return {
            "advantages": advantages,
            "final_ref_list": final_ref_list,  # If ``use_gt_ref`` is False, then the list will be empty, else it will be the ref captions.
            "rewards": rewards,  # Rewards after considering the length penalty and ref rewards
            "reward_mask": reward_mask,  # For dynamic reward
            "ref_mask": ref_mask,  # For importance sampling (distinguish between model output and gt)
            "final_group_sizes": final_group_sizes,
            "final_mean_grouped_rewards": final_mean_grouped_rewards,
            "final_std_grouped_rewards": final_std_grouped_rewards,
            "resp_ids_len": resp_ids_len,
        }

    def prepare_inputs_for_policy_model(
        self,
        inputs,
        final_ref_list,
        responses,
        device,
    ):
        # 1. process responses
        resp_list = [[] for _ in range(len(inputs["ref_captions"]))]  # List[List[str]]
        for ref_list, resp_l in zip(final_ref_list, resp_list):
            resp_l.extend([(ref + self.suffix_end_str) for ref in ref_list])

        for index, response in enumerate(responses):
            # NOTE: response index is calculated according to grpo group size
            resp_index = index // self.args.grpo_group_size
            resp_list[resp_index].append(response + self.suffix_end_str)

        completion_inputs = self.processor(
            text=[r for resps in resp_list for r in resps],
            return_tensors="pt",
            padding=True,
            padding_side="right",
            add_special_tokens=False,
        )

        # 2. process prompts
        if 'video' in inputs["multi_modal_data"]:
            prompt_inputs = self.processor(
                text=[inputs["prompt"]],
                images=None,
                videos=inputs["multi_modal_data"]["video"],
                return_tensors="pt",
                padding=True,
                padding_side="right",
                add_special_tokens=False,
            )
            prompt_inputs['second_per_grid_ts'] = [self.processor.image_processor.temporal_patch_size / inputs["mm_processor_kwargs"]['fps'][0]]
            prompt_inputs['pixel_values_videos'] = prompt_inputs['pixel_values_videos'].to(device)
            prompt_inputs['video_grid_thw']= prompt_inputs['video_grid_thw'].to(device)
        else:
            raise ValueError("Should contain a video")

        return prompt_inputs, completion_inputs

    def get_per_token_logp_list(self, logits, prefix_grouper, completion_ids) -> list:
        # NOTE: The last token of the prefix should be changed to the first input token of the suffix
        # NOTE: The new ``suffix_mask`` will include the last prefix token at the start
        prefix_output, prefix_mask, suffix_output, suffix_mask = (
            prefix_grouper.split_output(logits, include_prefix_last=1)
        )
        # # NOTE: Convert all tensors to float will cause high GPU mem
        # suffix_output = suffix_output[:, :-1].float()
        suffix_output = suffix_output[:, :-1]
        suffix_mask = suffix_mask[:, 1:]
        per_token_logp_list = []
        for i in range(len(suffix_output)):
            # Compute logps
            selected_logits = torch.gather(suffix_output[i], dim=-1, index=completion_ids[i].unsqueeze(-1)).squeeze(-1).float()
            # loop to reduce peak mem consumption
            logsumexp_values = torch.stack([torch.logsumexp(lg.float(), dim=-1) for lg in suffix_output[i]])
            per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
            per_token_logp_list.append(per_token_logps[suffix_mask[i]])
        return per_token_logp_list

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Run one GRPO step: sample, score, compute advantages and build the policy loss.

        Steps: generate completions with vLLM, score them with the judge model, normalize the
        rewards into advantages, run one prefix-grouped forward pass of the policy over the
        prompt plus all completions, and combine the per-sequence probability ratios with the
        advantages (and an optional KL term against the reference model).

        Args:
            model: The policy model being trained.
            inputs: Batch dict; ``"questions"`` and ``"ref_captions"`` are read here, the
                remaining keys by ``generate_vllm`` (which modifies ``inputs`` in place) and
                ``prepare_inputs_for_policy_model``.
            return_outputs: Must be falsy; returning outputs is not supported.
            num_items_in_batch: Accepted for interface compatibility; not used.

        Returns:
            The scalar loss tensor.

        Raises:
            ValueError: If ``return_outputs`` is truthy.
            AssertionError: If no vLLM model was provided at construction time.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        empty_cache()  # Release GPU/NPU memory

        # Generate completions
        assert self.vllm_model, "vllm_model should not be None"
        responses, resp_ids = self.generate_vllm(model, inputs)
        # Calculate rewards
        rewards_dict = self.calculate_rewards(responses, inputs["questions"], inputs["ref_captions"])
        # Calculate advantages
        adv_data = self.calculate_advantages(inputs["ref_captions"], resp_ids, rewards_dict["rewards"], responses)
        # Prepare inputs
        prompt_inputs, completion_inputs = self.prepare_inputs_for_policy_model(
            inputs=inputs,
            final_ref_list=adv_data["final_ref_list"],
            responses=responses,
            device=model.device,
        )
        prompt_ids = prompt_inputs.pop("input_ids").to(model.device)
        prompt_mask = prompt_inputs.pop("attention_mask").to(model.device)
        completion_ids = completion_inputs["input_ids"].to(model.device)
        completion_mask = completion_inputs["attention_mask"].to(model.device)
        prefix_grouper = PrefixGrouper.from_ungrouped_masks(
            # NOTE: The ``group_info`` can be automatically calculated through masks!
            prefix_mask=prompt_mask,
            suffix_mask=completion_mask,
            group_sizes=adv_data["final_group_sizes"].tolist(),
            device=model.device,
            padding_mode="right",
        )
        prompt_inputs["input_ids"] = prefix_grouper.concat_input(prompt_ids, prompt_mask, completion_ids, completion_mask)
        prompt_inputs["attention_mask"] = prefix_grouper.padding_mask
        # Forward
        res = model(**prompt_inputs, use_cache=False, prefix_grouper=prefix_grouper)
        # Compute loss
        completion_ids = prefix_grouper.convert_padding(completion_ids, completion_mask, padding_mode="right")
        per_token_logp_list = self.get_per_token_logp_list(res.logits, prefix_grouper, completion_ids)
        del res
        empty_cache()  # Release GPU
        # Per-sequence mean of ``exp(logp - stop_grad(logp_old))``. With ``logp_old = logp`` the
        # value is 1 and only the gradient matters.
        loss_list = [
            torch.exp(
                per_token_logps
                - (
                    per_token_logps.detach()
                    if not self.args.use_gt_ref_sampling_factor
                    # If the sample is from gt, then apply ``pi_old = 1``, else ``pi_old = pi``
                    else per_token_logps.detach() * mask_value.to(per_token_logps.dtype).to(per_token_logps.device)
                )
            ).mean()
            for per_token_logps, mask_value in zip(per_token_logp_list, adv_data["ref_mask"])
        ]
        rank0_print("Rewards:", adv_data["rewards"].shape, adv_data["rewards"], adv_data["advantages"])  # NOTE: DEBUG
        rank0_print("loss_list: ", loss_list)

        kl_loss = 0.0
        if self.ref_model is not None:
            with torch.inference_mode():
                ref_res = self.ref_model(**prompt_inputs, use_cache=False, prefix_grouper=prefix_grouper)
                empty_cache()
                ref_per_token_logp_list = self.get_per_token_logp_list(ref_res.logits, prefix_grouper, completion_ids)
                del ref_res
                empty_cache()  # Release GPU
                assert len(per_token_logp_list) == len(ref_per_token_logp_list)
            # NOTE: Should enable gradient here!!! (the policy log-probs must stay differentiable)
            # Per-sequence mean of ``exp(d) - d - 1`` with ``d = ref_logp - logp``.
            kl_loss = [(torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1).mean() for per_token_logps, ref_per_token_logps in zip(per_token_logp_list, ref_per_token_logp_list)]
            kl_loss = torch.stack(kl_loss)
            kl_loss = kl_loss * adv_data["reward_mask"].to(kl_loss.dtype).to(kl_loss.device)  # NOTE: mask the placeholder kl_loss

        # Compute advantage weighted loss
        loss = torch.stack(loss_list)
        loss = -(loss * adv_data["advantages"].to(loss.device) - self.kl_beta * kl_loss).mean()

        # ``kl_loss`` is a plain float without a reference model and a tensor (attached to the
        # graph) with one; only the float needs to be wrapped into a new tensor.
        if torch.is_tensor(kl_loss):
            kl_metric = kl_loss.detach().to(model.device)
        else:
            kl_metric = torch.tensor(kl_loss, device=model.device)
        self._metrics["kl_loss"].append(self.accelerator.gather_for_metrics(kl_metric).mean().item())
        self._metrics["avg_response_len"].append(self.accelerator.gather_for_metrics(adv_data["resp_ids_len"].to(model.device).to(torch.float)).mean().item())
        self._metrics["avg_cinematography_rewards"].append(self.accelerator.gather_for_metrics(rewards_dict["cinematography"].to(model.device)).mean().item())
        self._metrics["avg_visual_rewards"].append(self.accelerator.gather_for_metrics(rewards_dict["visual"].to(model.device)).mean().item())
        self._metrics["avg_frame_rewards"].append(self.accelerator.gather_for_metrics(rewards_dict["frame"].to(model.device)).mean().item())
        self._metrics["avg_temporal_rewards"].append(self.accelerator.gather_for_metrics(rewards_dict["temporal"].to(model.device)).mean().item())
        self._metrics["avg_precision_rewards"].append(self.accelerator.gather_for_metrics(rewards_dict["precision"].to(model.device)).mean().item())
        avg_rewards = self.accelerator.gather_for_metrics(rewards_dict["rewards"].to(model.device)).mean().item()
        self._metrics["avg_rewards"].append(avg_rewards)
        self._metrics["avg_rewards_std"].append(self.accelerator.gather_for_metrics(rewards_dict["rewards"].view(-1, self.args.grpo_group_size).std(dim=1).repeat_interleave(self.args.grpo_group_size, dim=0).to(model.device)).mean().item())

        if self.args.use_ema_gt_ref:
            # Track a running average of the gathered mean reward; ``calculate_advantages``
            # compares each group against it on the next step.
            if self.reward_ema is None:
                self.reward_ema = avg_rewards
            else:
                self.reward_ema = self.ema_gamma * self.reward_ema + (1 - self.ema_gamma) * avg_rewards
            rank0_print("reward_ema: ", self.reward_ema, "ema_gamma: ", self.ema_gamma)
        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged step metrics into ``logs``, forward them to ``Trainer.log`` and reset them.

        Args:
            logs: Values reported by the base trainer; collected metrics with the same key win.
            start_time: Forwarded to ``Trainer.log`` when the installed transformers version
                is at least ``4.47.0.dev0``; ignored otherwise.
        """
        # Average every metric over the steps collected since the previous call.
        averaged = {}
        for name, values in self._metrics.items():
            averaged[name] = sum(values) / len(values)

        merged = dict(logs)
        merged.update(averaged)

        accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if accepts_start_time:
            super().log(merged, start_time)
        else:  # transformers<=4.46
            super().log(merged)
        self._metrics.clear()
