# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import gc
import textwrap
import warnings
from collections import defaultdict
from contextlib import nullcontext
from typing import Any, Callable, Optional, Sized, Union
from unittest.mock import patch
from functools import partial

import torch
import torch.utils.data
import math
import torch.optim as optim
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformers
from accelerate.utils import broadcast_object_list, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from torch.utils.data import Sampler
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.import_utils import is_vllm_available
from trl.models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
from mimogpt.infer.eval_indices_recon_render import ImageEval
from mimogpt.infer.eval_indices_recon_render import parse_args_from_yaml

# from trainer import Trainer

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb

from selftokmodel import SelftokModel
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


def lr_linear_early_drop_with_warm_up(
    x,
    *,
    warm_up_steps=45,
    convert_steps=500,
    total_steps=2000,
    max_lr=5e-6,
    convert_lr=1e-6,
    min_lr=2e-7,
):
    """Return the learning-rate multiplier (relative to ``max_lr``) at step ``x``.

    The schedule is piecewise linear:

    * ``x < warm_up_steps``: ramps linearly from 0 up to 1.
    * ``warm_up_steps <= x < convert_steps``: decays from 1 to
      ``convert_lr / max_lr``.
    * ``convert_steps <= x < total_steps``: decays from ``convert_lr / max_lr``
      to ``min_lr / max_lr``.
    * ``x >= total_steps``: held at ``min_lr / max_lr``.

    Args:
        x: Current step index.
        warm_up_steps: Number of warm-up steps.
        convert_steps: Step at which the slow decay stage starts.
        total_steps: Step at which the floor ``min_lr`` is reached.
        max_lr: Peak learning rate (the factor 1 corresponds to this value).
        convert_lr: Learning rate reached at ``convert_steps``.
        min_lr: Learning rate reached at ``total_steps`` and kept afterwards.

    Returns:
        The multiplicative factor to apply to ``max_lr``.
    """
    max_factor = 1
    min_factor = min_lr / max_lr
    convert_factor = convert_lr / max_lr

    if x < warm_up_steps:
        # Warm-up stage.
        k = max_factor / warm_up_steps
        return k * x

    if x < convert_steps:
        # High learning-rate stage.
        k = (convert_factor - max_factor) / (convert_steps - warm_up_steps)
        return k * x - k * warm_up_steps + max_factor

    if x >= total_steps:
        # Without this clamp the last segment keeps extrapolating, so the factor falls below the floor and
        # eventually becomes negative when training runs longer than `total_steps`.
        return min_factor

    # Low learning-rate stage.
    k = (min_factor - convert_factor) / (total_steps - convert_steps)
    return k * x - k * convert_steps + convert_factor

class RepeatRandomSampler(Sampler):
    """Sampler that shuffles the dataset indices and repeats them in a structured way.

    The shuffled indices are cut into chunks of ``batch_size`` (a trailing incomplete chunk is dropped). Each
    chunk is emitted ``repeat_count`` times in a row, and within a chunk every index is emitted
    ``mini_repeat_count`` times consecutively.

    Args:
        data_source: Dataset to sample from; only its length is used.
        mini_repeat_count: Number of consecutive repetitions of each index.
        batch_size: Number of unique indices per chunk.
        repeat_count: Number of times each full chunk is replayed.
        seed: Optional seed for the sampler's private random generator.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.

    Example (``mini_repeat_count=2, batch_size=3, repeat_count=2`` on 7 samples, permutation
    ``[2, 4, 3, 1, 0, 6, 5]``)::

        2 2 4 4 3 3   2 2 4 4 3 3   1 1 0 0 6 6   1 1 0 0 6 6
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        if batch_size < 1:
            # A non-positive batch size can never be iterated (the chunking step would fail), so fail early.
            raise ValueError(f"`batch_size` must be a positive integer, but got {batch_size}.")
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()  # Local generator, so the global RNG state is left untouched
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()

        #    [2, 4, 3, 1, 0, 6, 5]
        # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
        chunks = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]

        #    [[2, 4, 3], [1, 0, 6], [5]]
        # -> [[2, 4, 3], [1, 0, 6]]
        chunks = [chunk for chunk in chunks if len(chunk) == self.batch_size]

        for chunk in chunks:
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

    def __len__(self) -> int:
        # Only complete chunks are yielded by `__iter__`, so the indices of the dropped remainder must not be
        # counted here, otherwise the reported length exceeds the number of yielded items.
        complete_chunks = self.num_samples // self.batch_size
        return complete_chunks * self.batch_size * self.mini_repeat_count * self.repeat_count


class GRPOTrainer(Trainer):
    
    _tag_names = ["trl", "grpo"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        """Build the policy model, a frozen reference model, the reward plumbing and the generation backend.

        Args:
            model: Either a path/ID that is loaded with `SelftokModel.from_pretrained`, or an already
                instantiated model.
            reward_funcs: A single reward function or a list of them.
            args: Training configuration. When `None`, a default `GRPOConfig` named after the model is created.
            train_dataset: Dataset forwarded to the parent `Trainer`.
            eval_dataset: Evaluation dataset(s) forwarded to the parent `Trainer`.
            processing_class: Tokenizer forwarded to the parent `Trainer`. When `None`, it is loaded from
                `args.llama_tokenizer_path` with left padding.
            reward_processing_classes: Tokenizer(s) paired with the reward functions; must match
                `reward_funcs` in length when given as a list.
            callbacks: Extra trainer callbacks.
            optimizers: `(optimizer, lr_scheduler)` pair forwarded to the parent `Trainer`.
            peft_config: When given, the policy model is wrapped with `get_peft_model`.

        Raises:
            ValueError: On an invalid `torch_dtype`, when `model_init_kwargs` is combined with an instantiated
                model, when the reward processing classes do not match the reward functions, or when the
                requested vLLM device does not exist.
            ImportError: When `peft_config` is given without PEFT installed, or `use_vllm` is set without vLLM
                installed.
        """
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")
        # `self.args` is read below before `super().__init__` runs, so it must already point to the resolved
        # config (and not to the `None` that may have been passed in).
        self.args = args

        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # The cache is always enabled here, regardless of `args.gradient_checkpointing`.
            model_init_kwargs["use_cache"] = True
            print(model_init_kwargs)
            # Load from the full identifier: the shortened `model_name` above is only meant for the run name.
            model = SelftokModel.from_pretrained(model_id, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            if not is_peft_available():
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            print('set peft!!!!!!!!')
            model = get_peft_model(model, peft_config)

        # The reference model is a second, frozen copy loaded from the same checkpoint identifier.
        self.ref_model = SelftokModel.from_pretrained(model_id, **model_init_kwargs)
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad_(False)

        # Processing class
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(args.llama_tokenizer_path, padding_side="left")

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # = 𝜇 in the GRPO paper
        self.num_iterations = args.num_iterations

        # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
        self._step = 0

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm

        # loss arguments
        self.beta = args.beta
        self.epsilon_low = args.epsilon_low
        self.epsilon_high = args.epsilon_high

        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                if vllm_device == "auto":
                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
                    )
                # Check that the requested device is not also used for training
                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
                    warnings.warn(
                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                        "behavior. It is recommended to use a dedicated device for vLLM."
                    )
                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
                # setting (profiling_patch).
                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                )
                with world_size_patch, profiling_patch:
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                        # dtype=self.args.vllm_dtype,
                        # # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                        # # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                        # # This is particularly useful here because we generate completions from the same prompts.
                        # enable_prefix_caching=True,
                        # max_model_len=self.args.vllm_max_model_len,
                    )
                self.sampling_params = SamplingParams(
                    n=self.num_generations,
                    temperature=args.temperature,
                    max_tokens=self.max_completion_length,
                )

            self._last_loaded_step = 0  # tag to avoid useless loading during grad checkpointing

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=args.temperature,
                num_return_sequences=self.num_generations,
                pad_token_id=processing_class.pad_token_id,
            )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags to the model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            elif self.is_fsdp_enabled:
                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

        self.set_special_tokens()
        self.set_model()

    def set_special_tokens(self):
        """Set up special-token ids, the text tokenizer, cached token tensors and generation flags.

        Reads `llama_tokenizer_path`, `text_vocab_size`, `image_vocab_size`, `task_type`, `guidance_scale`,
        `generate_with_cfg` and `set_epsilon` from `self.args` and stores the derived values on the trainer.
        """
        config = self.args
        tokenizer_path = config.llama_tokenizer_path

        # The four extra tokens (boi / eoi / cfg / rep) are numbered right after the text and image vocabularies.
        first_extra_id = config.text_vocab_size + config.image_vocab_size
        self.special_tokens_id = {
            "bos": 128000,
            "eos": 128001,
            "pad": 128002,
            "boi": first_extra_id,
            "eoi": first_extra_id + 1,
            "cfg": first_extra_id + 2,
            "rep": first_extra_id + 3,
            "ignore": -100,
        }

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer.pad_token_id = 128002

        device = self.accelerator.device

        def token_as_tensor(key):
            # (1, 1) int64 tensor holding a single special-token id, placed on the training device.
            return torch.LongTensor([self.special_tokens_id[key]]).unsqueeze(0).to(device)

        self.bos = token_as_tensor('bos')
        self.eos = token_as_tensor('eos')
        self.boi = token_as_tensor('boi')
        self.eoi = token_as_tensor('eoi')

        # Pre-tokenized fixed instruction, kept together with its ids and attention mask.
        self.system_instruction = self.tokenizer(
            'Please generate an image.',
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        self.system_ids = self.system_instruction['input_ids']
        self.system_mask = self.system_instruction['attention_mask']

        # Task and classifier-free-guidance settings mirrored from the config.
        self.task_type = config.task_type
        self.guidance_scale = config.guidance_scale
        self.generate_with_cfg = config.generate_with_cfg
        self.set_epsilon = config.set_epsilon

    def set_model(self, model=None):
        """Instantiate the image evaluator and the reward backend selected by `self.args.reward_list`.

        Args:
            model: Optional object forwarded as `model=` to `DPGEval` (only used on the 'dpg score' path).

        Raises:
            NotImplementedError: If `max_completion_length` is not 512, or if the task is 't2i' and
                `reward_list` contains neither 'geneval score' nor 'dpg score'.
        """
        cfg = parse_args_from_yaml(self.args.selftok_config)

        # selftok tokenizer: only the 512-token completion length is supported.
        if self.args.max_completion_length != 512:
            raise NotImplementedError()
        self.image_eval = ImageEval(cfg=cfg, ckpt_path=self.args.selftok_tokenizer_path, port=self.args.port)

        # Reward backends are only configured for text-to-image training.
        if self.task_type != 't2i':
            return

        if 'geneval score' in self.args.reward_list:
            from evaluate_image import GenEvalInf

            self.reward_model = None
            self.reward_model2 = GenEvalInf(weight=self.args.weight, device=self.accelerator.device)
            return

        if 'dpg score' not in self.args.reward_list:
            raise NotImplementedError()

        self.reward_model = None
        self.reward_model2 = None

        # The DPG evaluator is only created for 'mplug' reward checkpoints.
        if 'mplug' in self.args.reward_model_path:
            from evaluate_image import DPGEval

            self.dpg_eval = DPGEval(
                use_api=self.args.use_api,
                ckpt=self.args.reward_model_path,
                device=self.accelerator.device,
                model=model,
            )
                    

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    def _get_train_sampler(self) -> Sampler:
        """Build the training sampler that repeats every prompt once per generation.

        The number of unique prompts per chunk is the global batch size (per-device batch size x number of
        processes x gradient accumulation steps) integer-divided by `self.num_generations`; each chunk is
        replayed `self.num_iterations` times.
        """
        train_args = self.args
        global_batch_size = (
            train_args.per_device_train_batch_size
            * self.accelerator.num_processes
            * train_args.gradient_accumulation_steps
        )
        unique_prompts_per_chunk = global_batch_size // self.num_generations

        sampler = RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=unique_prompts_per_chunk,
            repeat_count=self.num_iterations,
            seed=train_args.seed,
        )
        return sampler
    
    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, cfgids=None):
        if cfgids is None:
            logits = model(input_ids, attention_mask=attention_mask).logits  # (B, L, V)
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
            input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        else:
            cur_input_ids = torch.cat((input_ids, cfgids), dim=0)
            cur_attn = torch.cat((attention_mask, attention_mask),dim=0)
            logits = model(cur_input_ids, attention_mask=cur_attn).logits
            logits = logits[:, :-1, :]
            cond_logits, uncond_logits = logits.chunk(2)
            
            if self.args.cfg_type == 'fix':
                guidance_scale = self.guidance_scale
                logits = cond_logits - (guidance_scale-1) / guidance_scale * uncond_logits
                
            elif self.args.cfg_type == 'adaptive':
                # compute the entropy
                entropy = F.softmax(cond_logits.detach(), dim=-1)
                entropy *= torch.log(entropy)
                entropy = -torch.sum(
                    entropy[:, :, self.args.text_vocab_size:self.args.text_vocab_size + self.args.image_vocab_size], 
                    dim=-1, 
                    keepdim=True
                ) #(bsz, 1, 1)
                
                entropy = torch.where(torch.isnan(entropy), 0, entropy)
                
                guidance_scale = torch.where(entropy < self.args.entropy_bound, self.args.min_cfg, self.guidance_scale).detach()
                
                logits = cond_logits - (guidance_scale-1) / guidance_scale * uncond_logits
                
            input_ids = input_ids[:, 1:]

        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
            
        return torch.stack(per_token_logps)

    def _prepare_inputs_for_sampling(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        
        if self.args.task_type == 't2i':
            ins_prompts = [inp['text'] for inp in inputs]
            if 'meta' in inputs[0].keys():
                meta_data = [inp['meta'] for inp in inputs]
            else:
                meta_data = None
                
            instruction = self.tokenizer(
                ins_prompts,
                return_tensors="pt",
                padding='longest',
                add_special_tokens=False,
                max_length=self.max_prompt_length,
                truncation=True
            ).to(device)
            
            bsz, L, dtype = instruction['input_ids'].size(0), instruction['input_ids'].size(1), instruction['input_ids'].dtype
            
            boi = self.boi.repeat_interleave(bsz, dim=0)
            eoi = self.eoi.repeat_interleave(bsz, dim=0)
            bos = self.bos.repeat_interleave(bsz, dim=0)
            
            prompt_ids = torch.cat([bos, instruction['input_ids'], boi], dim=1).to(device)

            attention_mask1 = torch.ones((bsz, 1), dtype=torch.int32, device=prompt_ids.device)
            attention_mask2 = instruction['attention_mask']
            attention_mask3 = torch.ones((bsz, 1), dtype=torch.int32, device=prompt_ids.device)
            prompt_mask = torch.cat((attention_mask1, attention_mask2, attention_mask3), dim=1).to(device)

            if self.guidance_scale is not None and self.generate_with_cfg:
                
                cfg_content = torch.ones((bsz, L), dtype=dtype) * self.special_tokens_id["cfg"]
                cfg_content = cfg_content.to(prompt_ids.device)
                cfg_ids = torch.cat([bos, cfg_content, boi], dim=1).to(device)
                
                cfg_ids = torch.where(prompt_ids == self.special_tokens_id['pad'], self.special_tokens_id['pad'], cfg_ids)
                
                input_prompt_ids = torch.cat([prompt_ids, cfg_ids], dim=0)
                input_prompt_mask = prompt_mask.repeat(2,1)
                my_guidance_scale = self.guidance_scale
            
            else:
                cfg_ids = None
                
                input_prompt_ids = prompt_ids
                input_prompt_mask = prompt_mask
                my_guidance_scale = None
            
            input_image_ids = None
            return {
                'input_prompt_ids': input_prompt_ids,
                'input_prompt_mask': input_prompt_mask,
                'my_guidance_scale': my_guidance_scale,
                'cfg_ids': cfg_ids,
                'ins_prompts': ins_prompts,
                'meta_data': meta_data,
                'input_image_ids': input_image_ids
            }
        else :
            raise NotImplementedError()
            
    def _generation_and_compute_rewards(self, return_dict) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        
        input_prompt_ids = return_dict['input_prompt_ids']
        input_prompt_mask = return_dict['input_prompt_mask']
        my_guidance_scale = return_dict['my_guidance_scale']
        cfg_ids = return_dict['cfg_ids']
        ins_prompts = return_dict['ins_prompts']
        meta_data = return_dict['meta_data']
        input_image_ids = return_dict['input_image_ids']

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    if is_compiled_module(unwrapped_model):
                        state_dict = unwrapped_model._orig_mod.state_dict()
                    else:
                        state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step
            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(ins_prompts)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text) * self.num_generations
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            # if self.accelerator.is_main_process:
                # print(completion_ids)
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
            prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            self.model.eval()
            
            if self.args.task_type == 't2i':
                with unwrap_model_for_generation(self.model_wrapped, self.accelerator, gather_deepspeed3_params=False) as unwrapped_model:
                     with (
                            FSDP.summon_full_params(self.model_wrapped, recurse=False)
                            if self.is_fsdp_enabled
                            else nullcontext()
                        ):
                        completion = unwrapped_model.generate_image(
                            input_ids=input_prompt_ids, 
                            ori_attention_mask=input_prompt_mask,
                            cfg_type = self.args.cfg_type,
                            use_past=True, 
                            top_k=4096, 
                            top_p=0.9, 
                            guidance_scale=my_guidance_scale,
                            img_seq_len=self.args.max_completion_length,
                            image_vocab_slice=(self.args.text_vocab_size, self.args.text_vocab_size + self.args.image_vocab_size)
                        )
            else:
                raise NotImplementedError()
            
            self.model.train()
            
            completion_ids = completion
            # Mask everything after the first EOS token
            completion_mask = torch.ones((completion_ids.size(0), completion_ids.size(1)), dtype=torch.long, device=device)

            prompt_length = input_prompt_ids.size(1)
            
            if self.guidance_scale:
                prompt_completion_ids = torch.cat((input_prompt_ids.chunk(2)[0], completion_ids), dim=1)
                
                # Concatenate prompt_mask with completion_mask for logit computation
                attention_mask = torch.cat([input_prompt_mask.chunk(2)[0], completion_mask], dim=1)  # (B*G, P+C)
            else:
                prompt_completion_ids = torch.cat((input_prompt_ids, completion_ids), dim=1)

                # Concatenate prompt_mask with completion_mask for logit computation
                attention_mask = torch.cat([input_prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        if self.guidance_scale:
            cfg_completion_ids = torch.cat((cfg_ids, completion_ids), dim=1)
        else:
            cfg_completion_ids = None
        
        with torch.no_grad():
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, cfgids=cfg_completion_ids
                )
                old_per_token_logps = old_per_token_logps[:, prompt_length - 1 :]
            else:
                old_per_token_logps = None
            
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, cfgids=cfg_completion_ids
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, cfgids=cfg_completion_ids
                    )

        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

        # Compute the rewards
        if self.args.reverse:
            my_completion_ids = torch.flip(completion_ids, dims=[1])
        else:
            my_completion_ids = completion_ids
        
        kwargs = dict()
        kwargs['image_save_path'] = self.args.image_save_path
        kwargs['text_vocab_size'] = self.args.text_vocab_size
        
        if self.task_type == 't2i':
            if 'geneval score' in self.args.reward_list:
                if meta_data is None:
                    raise ValueError()
                rewards_per_func = torch.zeros(len(ins_prompts), 1, device=device)
                
                output_reward_func = self.reward_funcs[0](self.reward_model2, self.image_eval, prd_tokens=my_completion_ids, instruction=ins_prompts, curstep=self.state.global_step, meta_data=meta_data, **kwargs)
                rewards_per_func[:, 0] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
            elif 'dpg score' in self.args.reward_list:
                rewards_per_func = torch.zeros(len(ins_prompts), 1, device=device)
                    
                output_reward_func = self.reward_funcs[0](self.image_eval, self.dpg_eval, prd_tokens=my_completion_ids, instruction=ins_prompts, curstep=self.state.global_step, meta_data=meta_data, use_std_reward=self.args.use_std_reward, **kwargs)
                rewards_per_func[:, 0] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
                
        
        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
        
        # Log the metrics
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
        
        if cfg_ids is not None:
            return {
                "prompt_ids": input_prompt_ids.chunk(2)[0],
                "prompt_mask": input_prompt_mask.chunk(2)[0],
                "cfg_ids": cfg_ids,
                "completion_ids": completion_ids,
                "completion_mask": completion_mask,
                "old_per_token_logps": old_per_token_logps,
                "ref_per_token_logps": ref_per_token_logps,
                "advantages": advantages,
            }
        return {
            "prompt_ids": input_prompt_ids,
            "prompt_mask": input_prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }
    
    
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Return the batch to train on, either freshly generated or replayed from the buffer.

        When `self.state.global_step` is a multiple of `self.num_iterations`, `inputs` is passed through
        `_prepare_inputs_for_sampling` and then `_generation_and_compute_rewards`; if `self.num_iterations > 1` the
        result is additionally stored in `self._buffered_inputs`. On every other step the previously stored batch is
        returned instead. `self._step` is incremented on every call.

        Args:
            inputs: Raw batch, forwarded as-is to `_prepare_inputs_for_sampling`.

        Returns:
            The output of `_generation_and_compute_rewards`, either new or taken from `self._buffered_inputs`.
        """
        if self.state.global_step % self.num_iterations != 0:
            # Replay step: reuse the batch stored for this gradient-accumulation slot.
            batch = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        else:
            sampling_inputs = self._prepare_inputs_for_sampling(inputs=inputs)
            batch = self._generation_and_compute_rewards(sampling_inputs)
            if self.num_iterations > 1:
                # Only keep the batch around when it will be replayed on later steps.
                slot = self._step % self.args.gradient_accumulation_steps
                self._buffered_inputs[slot] = batch

        self._step += 1
        return batch

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the policy loss for one prepared batch.

        Args:
            model: Model forwarded to `self._get_per_token_logps`.
            inputs: Prepared batch dict. Always reads `prompt_ids`, `prompt_mask`, `completion_ids`,
                `completion_mask` and `advantages`; reads `cfg_ids` when present, `ref_per_token_logps` when
                `self.beta != 0.0`, and `old_per_token_logps` when `self.num_iterations > 1`.
            return_outputs: Must be falsy; returning model outputs is not supported.
            num_items_in_batch: Unused by this implementation.

        Returns:
            Scalar loss tensor: the masked per-sequence mean of the per-token loss, averaged over the batch.

        Raises:
            ValueError: If `return_outputs` is truthy.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        # Optional second sequence (cfg prompt + completion) handed to the log-prob helper.
        if "cfg_ids" in inputs:
            cfg_comp_ids = torch.cat([inputs["cfg_ids"], completion_ids], dim=1)
        else:
            cfg_comp_ids = None

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, cfgids=cfg_comp_ids)
        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
        per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1 :]

        # Compute the KL divergence between the model and the reference model (only needed when beta is non-zero)
        use_kl = self.beta != 0.0
        if use_kl:
            ref_log_ratio = inputs["ref_per_token_logps"] - per_token_logps
            per_token_kl = torch.exp(ref_log_ratio) - ref_log_ratio - 1

        advantages = inputs["advantages"].unsqueeze(1)
        # With num_iterations == 1 the stored old log-probs are not used: the detached current log-probs stand in
        # for them, so the ratio below evaluates to 1 while gradients still flow through `per_token_logps`.
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
        ratio = torch.exp(per_token_logps - old_per_token_logps)

        # NOTE(review): the equality test is kept as-is because the type of `set_epsilon` is defined outside this
        # method; confirm it is a plain bool before switching to `not self.set_epsilon`.
        if self.set_epsilon == False:  # noqa: E712
            # Unclipped objective.
            per_token_loss = -(ratio * advantages)
        else:
            # Clipped objective: take the more pessimistic of the raw and clipped terms.
            clipped_ratio = torch.clamp(ratio, 1 - self.epsilon_low, 1 + self.epsilon_high)
            per_token_loss = -torch.min(ratio * advantages, clipped_ratio * advantages)

        # The KL penalty is applied in both branches, and only when it was actually computed. Previously the
        # unclipped branch referenced `per_token_kl` unconditionally and crashed when beta == 0.0.
        if use_kl:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        if use_kl:
            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        """
        Run one evaluation step and return only the loss.

        The inputs are first passed through `self._prepare_inputs`, then `self.compute_loss` is evaluated with
        gradients disabled and inside `self.compute_loss_context_manager()`.

        Args:
            model: Model forwarded to `compute_loss`.
            inputs: Raw batch, forwarded to `_prepare_inputs`.
            prediction_loss_only: Not read by this implementation.
            ignore_keys: Not read by this implementation.

        Returns:
            A `(loss, None, None)` tuple where `loss` is the detached mean of the computed loss.
        """
        prepared_batch = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                raw_loss = self.compute_loss(model, prepared_batch)
            # Reduce outside the loss context manager but still under no_grad.
            reduced_loss = raw_loss.mean().detach()
        return (reduced_loss, None, None)

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged buffered metrics into `logs`, forward them to the parent `log`, then reset the buffer.

        Args:
            logs: Values to log. When its first key starts with `"eval_"`, the buffered metrics receive the same
                prefix before being merged. An empty dict is accepted and treated as a training-time call.
            start_time: Forwarded to the parent `log` only when the installed transformers version is at least
                4.47.0.dev0.
        """
        # Average the metrics; a metric with no recorded values is skipped instead of dividing by zero.
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val}

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        # `next(..., "")` keeps an empty `logs` from raising StopIteration.
        first_key = next(iter(logs), "")
        if first_key.startswith("eval_"):
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()
    
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Set up the learning-rate scheduler. The optimizer of the trainer must have been set up either before this
        method is called or passed as an argument.

        When `self.args.use_self_lr_scheduler` is falsy the call is delegated to the parent implementation.
        Otherwise, if no scheduler exists yet, a `LambdaLR` is built around `lr_linear_early_drop_with_warm_up`
        configured from `self.args`, stored on `self.lr_scheduler`, and `self._created_lr_scheduler` is set.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer, *optional*): Optimizer to schedule; `self.optimizer` is used when
                this is `None`.

        Returns:
            The parent implementation's result, or `self.lr_scheduler` when the custom schedule is enabled.
        """
        if not self.args.use_self_lr_scheduler:
            return super().create_scheduler(num_training_steps, optimizer)

        # An existing scheduler is reused rather than rebuilt.
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        cfg = self.args
        schedule_fn = partial(
            lr_linear_early_drop_with_warm_up,
            warm_up_steps=cfg.warm_up_steps,
            convert_steps=cfg.convert_steps,
            total_steps=num_training_steps,
            max_lr=cfg.learning_rate,
            convert_lr=cfg.convert_lr,
            min_lr=cfg.min_lr,
        )
        target_optimizer = optimizer if optimizer is not None else self.optimizer
        self.lr_scheduler = optim.lr_scheduler.LambdaLR(target_optimizer, schedule_fn)
        self._created_lr_scheduler = True
        return self.lr_scheduler