import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.utils.data
import math
import torch.optim as optim
import transformers
from accelerate.utils import broadcast_object_list, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.import_utils import is_vllm_available
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
from eval_clip import clip_infer
from llama import JanusLLamaModel, JanusTrainModel, JanusEvalModel
from models.janus.models import MultiModalityCausalLM, VLChatProcessor
if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb
from copy import deepcopy
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


import numpy as np
from PIL import Image
from torchvision import transforms

def center_crop_arr(pil_image, image_size):
    """Downscale ``pil_image`` and cut a centered ``image_size`` square out of it.

    The image is first halved repeatedly with a box filter for as long as its
    shorter side is at least twice ``image_size``. It is then rescaled with
    bicubic interpolation by ``image_size / shorter_side`` and finally cropped
    around its center.

    Args:
        pil_image: Source PIL image.
        image_size: Side length, in pixels, of the square crop.

    Returns:
        A new PIL image built from the centered crop.
    """
    # Coarse pass: cheap successive halving before the final bicubic resize.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Fine pass: bring the shorter side to the requested size.
    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top: top + image_size, left: left + image_size])


def lr_linear_early_drop_with_warm_up(x, *, warm_up_steps=45, convert_steps=500, total_steps=2000, max_lr=5e-6, convert_lr=1e-6, min_lr=2e-7):
    """Return the learning-rate multiplier (relative to ``max_lr``) for step ``x``.

    The schedule is piecewise linear:

    * ``x < warm_up_steps``: ramps from 0 up to 1.
    * ``warm_up_steps <= x < convert_steps``: drops from 1 to
      ``convert_lr / max_lr``.
    * ``convert_steps <= x < total_steps``: drops from ``convert_lr / max_lr``
      to ``min_lr / max_lr``.
    * ``x >= total_steps``: stays at ``min_lr / max_lr``.

    Args:
        x: Current step.
        warm_up_steps: Number of warm-up steps.
        convert_steps: Step at which the schedule switches to the slow decay.
        total_steps: Step at which the minimum factor is reached.
        max_lr: Peak learning rate; the returned factor is relative to it.
        convert_lr: Learning rate reached at ``convert_steps``.
        min_lr: Learning rate reached at ``total_steps`` and held afterwards.

    Returns:
        The multiplicative factor to apply to ``max_lr``.
    """
    max_factor = 1
    min_factor = min_lr / max_lr
    convert_factor = convert_lr / max_lr

    if x < warm_up_steps:
        # Warm-up stage.
        k = max_factor / warm_up_steps
        return k * x

    if x < convert_steps:
        # High learning-rate stage.
        k = (convert_factor - max_factor) / (convert_steps - warm_up_steps)
        return k * x - k * warm_up_steps + max_factor

    if x >= total_steps:
        # Hold the floor once the schedule is exhausted. Extrapolating the
        # last segment would push the factor below ``min_factor`` and
        # eventually negative; this guard also avoids dividing by zero when
        # ``total_steps <= convert_steps``.
        return min_factor

    # Low learning-rate stage.
    k = (min_factor - convert_factor) / (total_steps - convert_steps)
    return k * x - k * convert_steps + convert_factor


class GRPOTrainer(Trainer):
    """`Trainer` subclass that applies a GRPO-style objective to image-token generation.

    For every batch the trainer builds chat-formatted prompts, samples image
    tokens through the policy's ``generate_image`` method, scores the produced
    images with the configured reward functions, normalizes the rewards per
    group into advantages, and optimizes a policy loss with a KL penalty
    against a frozen reference model.
    """

    # Tags registered on the model through `add_model_tags` in `__init__`.
    _tag_names = ["trl", "grpo"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        """Build the policy, the frozen reference model and all training helpers.

        Args:
            model: Path/identifier of the policy checkpoint, or an already
                instantiated model.
            reward_funcs: One reward function (or a list of them).
            args: Training configuration. When ``None`` a default
                ``GRPOConfig`` named ``"<model basename>-GRPO"`` is created.
            train_dataset: Training dataset forwarded to ``Trainer``.
            eval_dataset: Evaluation dataset(s) forwarded to ``Trainer``.
            processing_class: Tokenizer forwarded to ``Trainer``. When ``None``
                the tokenizer of a ``VLChatProcessor`` loaded from a fixed
                path is used.
            reward_processing_classes: Tokenizer(s) matching ``reward_funcs``;
                only consulted for reward functions that are
                ``PreTrainedModel`` instances.
            callbacks: Extra trainer callbacks.
            optimizers: ``(optimizer, lr_scheduler)`` pair forwarded to
                ``Trainer``.
            peft_config: Optional PEFT configuration used to wrap the policy.

        Raises:
            ValueError: On an invalid ``torch_dtype`` in ``model_init_kwargs``,
                on ``model_init_kwargs`` combined with an instantiated model,
                on mismatched ``reward_processing_classes``, or on an
                unavailable vLLM device.
            ImportError: If ``peft_config`` is given but PEFT is not
                installed, or ``use_vllm`` is set but vLLM is not installed.
        """
        # Keep the *full* identifier as a string: it is used both to load the
        # policy/reference checkpoints and for the name-based substring checks
        # performed elsewhere in this class.
        model_id = model if isinstance(model, str) else model.config._name_or_path
        self.model_name = model_id

        if args is None:
            # Only the default output name uses the basename of the identifier.
            args = GRPOConfig(f"{model_id.split('/')[-1]}-GRPO")
        self.args = args

        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            model_init_kwargs["use_cache"] = None
            # NOTE(review): `model_init_kwargs` is validated above but not
            # forwarded; the policy is always loaded in bfloat16.
            model = JanusLLamaModel.from_pretrained(
                model_id, revision='main', trust_remote_code=False, torch_dtype=torch.bfloat16
            )
            if args.pretrain_path is not None:
                model.load_state_dict(torch.load(f"{args.pretrain_path}", map_location="cpu"))
        else:
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            if not is_peft_available():
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

        # Frozen reference model used for the KL term.
        self.ref_model = JanusLLamaModel.from_pretrained(
            model_id, revision='main', trust_remote_code=False, torch_dtype=torch.bfloat16
        )
        if args.ref_pretrain_path is not None:
            self.ref_model.load_state_dict(torch.load(f"{args.ref_pretrain_path}", map_location="cpu"))
        for param in self.ref_model.parameters():
            param.requires_grad = False
        self.ref_model.eval()

        # Processing class
        if processing_class is None:
            # specify the path to the model
            vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained('/mnt/prev_nas/refine_draw/models/deepseek-ai/Janus-Pro-7B')
            processing_class = vl_chat_processor.tokenizer

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        # Group sizes used by `_prepare_inputs`: without / with the two extra
        # reference images appended to every group.
        self.num_generations_1 = args.num_generations
        self.num_generations_2 = args.num_generations - 2
        self.use_vllm = args.use_vllm

        self.beta = args.beta
        self.epsilon = args.epsilon

        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                if vllm_device == "auto":
                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
                    )
                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
                    warnings.warn(
                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                        "behavior. It is recommended to use a dedicated device for vLLM."
                    )
                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                )
                with world_size_patch, profiling_patch:
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                    )
                self.sampling_params = SamplingParams(
                    n=self.num_generations,
                    temperature=args.temperature,
                    max_tokens=self.max_completion_length,
                )

            self._last_loaded_step = 0

            self.accelerator.wait_for_everyone()
        else:
            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=args.temperature,
                num_return_sequences=self.num_generations,
                pad_token_id=processing_class.pad_token_id,
            )

        self.model_accepts_loss_kwargs = False

        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

        # Preprocessing applied to reference images before they are encoded
        # into image tokens: center crop to 384 px, then scale to [-1, 1].
        self.gen_transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])

        self.set_special_tokens()
        self.set_model()

    def set_special_tokens(self, task_type='t2i'):
        """Load the chat processor/tokenizer and cache generation options from ``self.args``.

        The processor checkpoint is picked from substrings of
        ``self.model_name``. ``task_type`` is accepted but not used here.
        """
        # NOTE(review): the substring checks are case-sensitive, so a name that
        # only contains "Pro" (capitalized) selects the Janus-1.3B processor.
        # Confirm this is intended.
        if 'pro' not in self.model_name:
            processor_path = "/mnt/prev_nas/refine_draw/models/deepseek-ai/Janus-1.3B"
        elif '7' not in self.model_name:
            processor_path = '/mnt/prev_nas/refine_draw/models/deepseek-ai/Janus-Pro-1B'
        else:
            processor_path = '/mnt/prev_nas/refine_draw/models/deepseek-ai/Janus-Pro-7B'

        self.vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(processor_path)
        self.tokenizer = self.vl_chat_processor.tokenizer
        self.tokenizer.padding_side = 'left'

        # Generation / loss switches mirrored from the training arguments.
        config = self.args
        self.guidance_scale = config.guidance_scale
        self.generate_with_cfg = config.generate_with_cfg
        self.set_epsilon = config.set_epsilon

       
    def set_model(self):
        from intern_img import InternVLReward
        self.task_type='t2i'

        if self.task_type=='t2i':
            if self.args.use_clip_score:
                self.reward_model = clip_infer(model=self.reward_model2.clip_model, transform=self.reward_model2.transform)
            elif self.args.use_internvl:
                self.reward_model = InternVLReward()
            else:
                self.reward_model = None 


        for n,p in self.ref_model.named_parameters():
            p.require_grad = False


    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    def _get_per_token_logps(self, model, text_inputs_ids, img_ids, attention_mask, logits_to_keep=0, addcfg=True):
        inputs_embeds = model.language_model.get_input_embeddings()(text_inputs_ids)
        if img_ids.shape[0] < text_inputs_ids.shape[0]:
            new_img_ids = torch.repeat_interleave(img_ids, 2, dim=0) 
        else:
            new_img_ids = img_ids
        
        visual_embeds = model.gen_aligner(model.gen_embed(new_img_ids))
        inputs_embeds = torch.cat([inputs_embeds, visual_embeds], dim=1)
        
        if addcfg == False:
            outputs = model.language_model.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)  # (B, L, V)
            hidden_states = outputs.last_hidden_state
            logits = model.gen_head(hidden_states)
            logits = logits[:, -1-logits_to_keep:-1, :]  
            input_ids = img_ids.long()  
        
        else:
            outputs = model.language_model.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)  # (B, L, V)
            hidden_states = outputs.last_hidden_state
            logits = model.gen_head(hidden_states)
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_cond - (self.guidance_scale-1) / self.guidance_scale *logit_uncond
            logits = logits[:, -1-logits_to_keep:-1, :]  
            input_ids = img_ids.long()

        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) # index 为什么是input_ids_row？，logits和input_ids的顺序不应该是对应的？
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        """Generate images for a batch, score them and assemble the tensors for the loss.

        Each element of ``inputs`` is read as a mapping with the keys ``'text'``
        (iterable of prompts), ``'meta'`` (iterable of per-prompt metadata) and
        ``'pair'`` (only the first element's value is used). When the first
        element also has ``'img1'``, every element must provide ``'img1'`` and
        ``'img2'``; these reference images are encoded to image tokens and
        appended to each generation group.

        Returns:
            A dict with ``prompt_ids``, ``prompt_mask``, ``completion_ids``,
            ``completion_mask``, ``ref_per_token_logps`` and ``advantages``.

        Raises:
            NotImplementedError: If ``args.use_vllm`` is set.
        """
        device = self.accelerator.device

        pair_or_not = inputs[0]['pair']

        # --- optional reference images -------------------------------------
        if 'img1' in inputs[0].keys():
            good_img, bad_img = [], []
            for inp in inputs:
                good_img.extend((inp['img1'], inp['img2']))
                bad_img.extend((inp['img2'], inp['img1']))
            self.num_generations = self.num_generations_2
            with torch.no_grad():
                good_img_id = self._encode_reference_images(good_img, device)
                bad_img_id = self._encode_reference_images(bad_img, device)
        else:
            self.num_generations = self.num_generations_1
            good_img = bad_img = None
            good_img_id = bad_img_id = None

        # --- prompts ---------------------------------------------------------
        baseline_prompts = []
        meta_data = []
        for inp in inputs:
            baseline_prompts.extend(inp['text'])
            meta_data.extend(inp['meta'])

        if 'pro' not in self.model_name:
            user_role, assistant_role = "User", "Assistant"
        else:
            user_role, assistant_role = "<|User|>", "<|Assistant|>"

        allprompts = []
        for prompt in baseline_prompts:
            conversation = [
                {"role": user_role, "content": prompt},
                {"role": assistant_role, "content": ""},
            ]
            sft_format = self.vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
                conversations=conversation,
                sft_format=self.vl_chat_processor.sft_format,
                system_prompt="",
            )
            allprompts.append(sft_format + self.vl_chat_processor.image_start_tag)

        instruction = self.tokenizer(
            allprompts,
            return_tensors="pt",
            padding='longest',
        ).to(device)

        if self.args.use_vllm:
            raise NotImplementedError

        # --- generation ------------------------------------------------------
        group_size = self.num_generations
        generation_prompts = [prompt for prompt in baseline_prompts for _ in range(group_size)]
        prompt_ids = torch.repeat_interleave(instruction['input_ids'], group_size, dim=0)
        prompt_mask = torch.repeat_interleave(instruction['attention_mask'], group_size, dim=0)

        set_cfg = bool(self.guidance_scale is not None and self.generate_with_cfg)
        my_guidance_scale = self.guidance_scale if set_cfg else None

        self.model.set_eval()
        try:
            with unwrap_model_for_generation(self.model, self.accelerator) as fsdp_model:
                img_ids, imgs, (text_ids, all_attention_mask) = fsdp_model.generate_image(
                    vl_chat_processor=self.vl_chat_processor,
                    input_ids=prompt_ids, attention_mask=prompt_mask,
                    cur_step=self.state.global_step,
                    set_cfg=set_cfg, cfg_weight=my_guidance_scale,
                    img_path='/ossfs/workspace/imgs2',
                    instruction=generation_prompts,
                )
        finally:
            # Always restore training mode, even if generation fails.
            self.model.set_train()

        logits_to_keep = img_ids.size(1)

        # Two reference images are appended to every group when available.
        # (`is not None`: the truth value of a multi-element tensor is ambiguous.)
        num_generations = group_size
        if good_img_id is not None:
            num_generations += 2

        prompts = [prompt for prompt in baseline_prompts for _ in range(num_generations)]
        meta_data = [md for md in meta_data for _ in range(num_generations)]

        # --- merge reference images into each generation group --------------
        if good_img_id is not None:
            merged_imgs = []
            for prompt_idx, start in enumerate(range(0, len(imgs), group_size)):
                merged_imgs += imgs[start:start + group_size]
                merged_imgs.append(good_img[prompt_idx])
                merged_imgs.append(bad_img[prompt_idx])
            imgs = merged_imgs

            # With CFG the text batch holds two rows (cond/uncond) per image.
            rows_per_image = 2 if img_ids.shape[0] < text_ids.shape[0] else 1
            text_groups = text_ids.view(-1, group_size * rows_per_image, text_ids.shape[-1])
            mask_groups = all_attention_mask.view(-1, group_size * rows_per_image, all_attention_mask.shape[-1])
            ref_text = text_groups[:, :rows_per_image, :]
            ref_mask = mask_groups[:, :rows_per_image, :]

            img_groups = img_ids.view(-1, group_size, img_ids.shape[-1])
            # NOTE(review): assumes one reference image per prompt and that the
            # encoder's ids can be laid out as (num_prompts, 1, tokens_per_image).
            good_img_id = good_img_id.reshape(img_groups.size(0), 1, -1)
            bad_img_id = bad_img_id.reshape(img_groups.size(0), 1, -1)

            img_ids = torch.cat([img_groups, good_img_id, bad_img_id], dim=1)
            img_ids = img_ids.view(-1, img_ids.shape[-1])

            # The text ids must be extended exactly like the image ids and the
            # mask, otherwise the text and image batches disagree downstream.
            text_ids = torch.cat([text_groups, ref_text, ref_text], dim=1)
            text_ids = text_ids.view(-1, text_ids.shape[-1])

            all_attention_mask = torch.cat([mask_groups, ref_mask, ref_mask], dim=1)
            all_attention_mask = all_attention_mask.view(-1, all_attention_mask.shape[-1])

        completion_mask = torch.ones((img_ids.size(0), img_ids.size(1)), dtype=torch.long, device=device)

        # --- reference log-probabilities --------------------------------------
        # NOTE(review): `addcfg=True` is used unconditionally, which requires
        # `guidance_scale` to be set and the text batch to hold cond/uncond
        # pairs - confirm for runs with `generate_with_cfg` disabled.
        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, text_ids, img_ids, all_attention_mask, logits_to_keep=logits_to_keep, addcfg=True
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, text_ids, img_ids, all_attention_mask, logits_to_keep=logits_to_keep, addcfg=True
                    )

        # --- rewards ----------------------------------------------------------
        with torch.no_grad():
            # Column 0: metadata-based reward, column 1: prompt-based reward.
            rewards_per_func = torch.zeros(len(prompts), 2, device=device)
            # One iteration per original prompt, i.e. per group of images.
            for group_idx in range(len(baseline_prompts)):
                stidx = group_idx * num_generations
                edidx = stidx + num_generations
                cur_meta = meta_data[stidx:edidx]
                curimgs = imgs[stidx:edidx]
                curpops = prompts[stidx:edidx]
                if cur_meta[0] is None:
                    output_reward_func = self.reward_funcs[1](self.reward_model, images=curimgs, prompts=curpops)
                    rewards_per_func[stidx:edidx, 1] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
                else:
                    # NOTE(review): `self.reward_model2` is not assigned
                    # anywhere in this class; it must be provided externally.
                    output_reward_func = self.reward_funcs[0](self.reward_model2, images=curimgs, instruction=curpops, meta_data=cur_meta)
                    rewards_per_func[stidx:edidx, 0] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # --- group-normalized advantages -----------------------------------
        if pair_or_not:
            # For paired inputs, two consecutive prompts share one group.
            num_generations *= 2
        rewards = rewards_per_func.sum(dim=1)

        grouped_rewards = rewards.view(-1, num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(num_generations, dim=0)
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # --- metrics ------------------------------------------------------------
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": text_ids,
            "prompt_mask": all_attention_mask,
            "completion_ids": img_ids,
            'completion_mask': completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    def _encode_reference_images(self, images, device):
        """Transform ``images`` with ``self.gen_transform`` and return the vision encoder's image ids."""
        vision_model = self.model.gen_vision_model
        pixel_values = torch.stack([self.gen_transform(img) for img in images], dim=0)
        # Match the encoder's device and parameter dtype; the transform yields
        # CPU float32 tensors.
        reference_param = next(vision_model.parameters(), None)
        if reference_param is not None:
            pixel_values = pixel_values.to(device=device, dtype=reference_param.dtype)
        else:
            pixel_values = pixel_values.to(device)
        _, _, all_image_ids = vision_model.encode(pixel_values)
        return all_image_ids[2]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs['completion_mask']

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, prompt_ids, completion_ids, prompt_mask, logits_to_keep, addcfg=True)
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        advantages = inputs["advantages"]
        if self.set_epsilon==False:
            per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
            per_token_loss = -(per_token_loss - self.beta * per_token_kl)


        else:
            coef_1 = torch.exp(per_token_logps - per_token_logps.detach())
            coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
            if self.beta != 0.0:
                per_token_loss = per_token_loss + self.beta * per_token_kl

        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        """Run one evaluation step and return ``(loss, None, None)``.

        The batch is passed through ``_prepare_inputs`` and the loss is
        computed without gradient tracking. ``prediction_loss_only`` and
        ``ignore_keys`` are accepted but not used.
        """
        prepared = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                raw_loss = self.compute_loss(model, prepared)
            reduced_loss = raw_loss.mean().detach()
        return reduced_loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the averaged buffered metrics into ``logs``, forward them, and reset the buffer.

        When the first key of ``logs`` starts with ``"eval_"``, the buffered
        metrics get the same prefix.
        """
        # Average every buffered metric.
        averaged = {}
        for name, values in self._metrics.items():
            averaged[name] = sum(values) / len(values)

        first_key = next(iter(logs))
        if first_key.startswith("eval_"):
            averaged = {f"eval_{name}": value for name, value in averaged.items()}

        merged = {**logs, **averaged}
        # `start_time` is only forwarded for transformers >= 4.47.0.dev0.
        forwards_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if forwards_start_time:
            super().log(merged, start_time)
        else:
            super().log(merged)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Write a draft model card (``README.md``) into ``self.args.output_dir``.

        Nothing is done on processes other than the world-zero process.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # A base model is only reported when the configured name is not a
        # local directory.
        base_model = None
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]
        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year         = 2024,
                eprint       = {arXiv:2402.03300},
            }
            """
        )

        wandb_url = None
        if is_wandb_available() and wandb.run is not None:
            wandb_url = wandb.run.get_url()

        card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb_url,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        readme_path = os.path.join(self.args.output_dir, "README.md")
        card.save(readme_path)
    
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if not self.args.use_self_lr_scheduler:
            return super().create_scheduler(num_training_steps, optimizer)
        from functools import partial
        if self.lr_scheduler is None:
            lr_lambda = partial(lr_linear_early_drop_with_warm_up,
                                warm_up_steps=self.args.warm_up_steps,
                                convert_steps=self.args.convert_steps,
                                total_steps=num_training_steps,
                                max_lr=self.args.learning_rate,
                                convert_lr=self.args.convert_lr,
                                min_lr=self.args.min_lr
                                )
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer if optimizer is None else optimizer, 
                                                            lr_lambda
                                                            )
            self._created_lr_scheduler = True
        return self.lr_scheduler

