
import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union
import warnings
from unittest.mock import patch
from accelerate.utils.other import is_compiled_module
from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
from trl.import_utils import is_vllm_available

import copy
from PIL import Image

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb

# A reward is one of: a model id/path string (the trainer loads it as a single-label
# AutoModelForSequenceClassification), an already loaded PreTrainedModel, or a Python
# callable. The trainer invokes callables with keyword arguments (prompts=..., completions=...,
# current_step=..., plus extra dataset columns) and converts the result to one float per completion.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class MultiModalGRPOTrainer(Trainer):
    """Group Relative Policy Optimization (GRPO) trainer for text and vision-language models.

    For every prompt in the per-device batch, ``args.num_generations`` completions are
    sampled (with a vLLM engine on the main process, or with ``model.generate``), scored by
    ``reward_funcs`` and turned into group-normalised advantages. The loss is the GRPO
    policy-gradient term plus ``args.beta`` times a per-token KL estimate against a
    reference model (or against the adapter-disabled policy when PEFT is used).

    Batch layout: the prompt batch of size ``B`` is *tiled* ``num_generations`` times, so
    row ``r`` of every per-completion tensor/list belongs to example ``r % B``. Generation,
    reward computation and advantage grouping all follow this single layout.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        max_pixels: Optional[int] = 12845056,
        min_pixels: Optional[int] = 3136,
        attn_implementation: str = "flash_attention_2",
        use_vllm_for_gen: bool = True
    ):
        """Build the trainer, its reference model, processors and generation backend.

        Args:
            model: Model id/path or an instantiated model. Ids whose lowercased name contains a
                Qwen2-VL / Qwen2.5-VL marker are loaded with the matching
                ``*ForConditionalGeneration`` class, anything else with ``AutoModelForCausalLM``.
            reward_funcs: One reward or a list of rewards. Strings are loaded as single-label
                ``AutoModelForSequenceClassification`` models.
            args: Training config; defaults to ``GRPOConfig("<model-name>-GRPO")``.
            train_dataset / eval_dataset: Datasets whose rows provide a ``"prompt"`` and,
                optionally, an ``"image"`` (a path / PIL image, or a list of them).
            processing_class: Tokenizer or processor. Defaults to ``AutoProcessor`` for Qwen VL
                ids and to a left-padding ``AutoTokenizer`` otherwise.
            reward_processing_classes: Tokenizers for model-based rewards, one per reward
                function (a single object is wrapped into a list).
            callbacks / optimizers: Forwarded to ``transformers.Trainer``.
            peft_config: If given, the policy is wrapped with ``get_peft_model``; outside ZeRO-3
                the adapter-disabled policy then serves as the reference model.
            max_pixels / min_pixels: Applied to the image processor of an auto-loaded processor only.
            attn_implementation: Injected into the model loading kwargs.
            use_vllm_for_gen: Generate with a vLLM engine on the main process instead of ``generate``.

        Raises:
            ImportError: vLLM or PEFT is requested but not installed.
            ValueError: ``model_init_kwargs`` used with an instantiated model, a reward processing
                class count that differs from the reward function count, or an invalid vLLM device.
        """
        # Fail early and on every process (the vLLM engine is only built on the main one).
        if use_vllm_for_gen and not is_vllm_available():
            raise ImportError("use_vllm_for_gen=True requires vLLM, which is not installed.")
        if peft_config is not None and not is_peft_available():
            raise ImportError("peft_config was given but PEFT is not installed.")

        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Work on a copy so the tweaks below do not leak back into ``args.model_init_kwargs``.
        model_init_kwargs = dict(args.model_init_kwargs or {})
        model_init_kwargs["attn_implementation"] = attn_implementation
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)
            # use_cache is forced off when gradient checkpointing is enabled.
            model_init_kwargs["use_cache"] = False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
            model = self._load_policy_model(model_id, model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError("model_init_kwargs can only be used when model is a string.")

        self.model_id = model_id
        self.use_vllm = use_vllm_for_gen

        if peft_config is not None:
            model = get_peft_model(model, peft_config)
        # NOTE(review): assumes PEFT-wrapped models expose a ``peft_config`` attribute (PEFT API);
        # only used to decide whether weights can be synced into vLLM.
        is_peft_model = getattr(model, "peft_config", None) is not None

        if is_deepspeed_zero3_enabled():
            # Under ZeRO-3 a separate reference copy is loaded from ``model_id``.
            self.ref_model = self._load_policy_model(model_id, model_init_kwargs)
        elif peft_config is None:
            self.ref_model = create_reference_model(model)
        else:
            # With PEFT the reference is the policy with its adapters disabled (see compute_loss).
            self.ref_model = None

        if processing_class is None:
            if self._qwen_vl_family(model_id) is not None:
                processing_class = AutoProcessor.from_pretrained(model_id)
                processing_class.pad_token_id = processing_class.tokenizer.pad_token_id
                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
                processing_class.image_processor.max_pixels = max_pixels
                processing_class.image_processor.min_pixels = min_pixels
            else:
                processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
        # Resolved once here: a user-supplied processor may expose these ids only on ``.tokenizer``.
        pad_token_id = self._special_token_id(processing_class, "pad_token_id")
        self._pad_token_id = pad_token_id
        self._eos_token_id = self._special_token_id(processing_class, "eos_token_id")

        # Copy so replacing model ids with loaded models does not mutate the caller's list.
        reward_funcs = list(reward_funcs) if isinstance(reward_funcs, list) else [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif isinstance(reward_processing_classes, list):
            reward_processing_classes = list(reward_processing_classes)
        else:
            reward_processing_classes = [reward_processing_classes]
        # Without this check zip() would silently drop reward functions lacking a processing class.
        if len(reward_processing_classes) != len(reward_funcs):
            raise ValueError("Number of reward processing classes must match reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        def data_collator(features):
            # Identity collator: compute_loss receives the raw dataset rows.
            return features

        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations
        self.beta = args.beta

        # Mark the "estimate_tokens" warning as already issued; batches here are raw row dicts.
        model.warnings_issued["estimate_tokens"] = True

        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

        if self.use_vllm:
            # Syncing feeds the policy state_dict to vLLM's load_weights; PEFT state dicts carry
            # adapter-prefixed names, so syncing is skipped (with a warning) for PEFT models.
            self._vllm_weight_sync = not is_peft_model
            if not self._vllm_weight_sync:
                warnings.warn(
                    "Policy weights are not synced to vLLM for PEFT models; "
                    "vLLM keeps generating with the initial weights."
                )
            if self.accelerator.is_main_process:
                # One visible GPU: share it. Otherwise take the index right after the training
                # processes (assumes training ranks occupy cuda:0..num_processes-1).
                if torch.cuda.device_count() == 1:
                    vllm_device = "cuda:0"
                else:
                    vllm_device = f"cuda:{self.accelerator.num_processes}"
                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                    raise ValueError("Invalid vLLM device requested.")
                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
                    warnings.warn("vLLM device is same as training device.")
                # Build the engine with get_world_size patched to 1 and vLLM's profiling
                # memory assertion disabled.
                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                profiling_patch = patch("vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None)
                with world_size_patch, profiling_patch:
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=0.7,
                        limit_mm_per_prompt={"image": 2},
                        enable_prefix_caching=True,
                    )
                self.sampling_params = SamplingParams(
                    temperature=args.temperature,
                    top_p=0.9,
                    top_k=50,
                    max_tokens=self.max_completion_length,
                )

            # Global step whose weights vLLM currently holds (it starts from the initial checkpoint).
            self._last_loaded_step = 0

            self.accelerator.wait_for_everyone()
        else:
            # compute_loss already tiles each prompt num_generations times, so one sample per row.
            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=args.temperature,
                num_return_sequences=1,
                pad_token_id=pad_token_id,
            )

    @staticmethod
    def _qwen_vl_family(model_id):
        """Return ``"qwen2_vl"``, ``"qwen2_5_vl"`` or ``None`` from substrings of ``model_id`` (case-insensitive)."""
        name = model_id.lower()
        if any(tag in name for tag in ("qwen2-vl", "qwen2_vl", "qwen2vl")):
            return "qwen2_vl"
        if any(tag in name for tag in ("qwen2.5-vl", "qwen2.5_vl", "qwen2.5vl", "qwen2_5_vl")):
            return "qwen2_5_vl"
        return None

    @classmethod
    def _load_policy_model(cls, model_id, model_init_kwargs):
        """Load ``model_id`` with the Qwen VL class its name suggests, else ``AutoModelForCausalLM``."""
        family = cls._qwen_vl_family(model_id)
        if family == "qwen2_vl":
            return Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
        if family == "qwen2_5_vl":
            return Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
        return AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)

    @staticmethod
    def _special_token_id(processing_class, name):
        """Return ``processing_class.<name>``, falling back to ``processing_class.tokenizer.<name>``."""
        token_id = getattr(processing_class, name, None)
        if token_id is None:
            token_id = getattr(getattr(processing_class, "tokenizer", None), name, None)
        return token_id

    def _set_signature_columns_if_needed(self):
        # Only "prompt" is a model-signature column; other columns are consumed in compute_loss.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    def _get_per_token_logps(self, model, **inputs):
        """Return the log-probability of every next token of ``inputs["input_ids"]``.

        Column ``t`` of the result scores token ``t + 1`` under the logits at position ``t``,
        so the output has one column fewer than ``input_ids``. The log-softmax is computed
        row by row (presumably to limit peak memory of the full-vocabulary softmax).
        """
        logits = model(**inputs).logits
        logits = logits[:, :-1, :]
        input_ids = inputs['input_ids'][:, 1:]
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    def _move_model_to_vllm(self):
        """Load the current policy weights into the main-process vLLM engine.

        Must be called on every process: ``unwrap_model_for_generation`` may gather ZeRO-3
        parameters collectively. The weights are loaded while still inside that context so
        the state_dict tensors are the gathered ones.
        """
        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            module = unwrapped_model._orig_mod if is_compiled_module(unwrapped_model) else unwrapped_model
            state_dict = module.state_dict()
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())

    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        # Rows stay raw here; compute_loss tokenizes and moves tensors to the device itself.
        return inputs

    @staticmethod
    def _load_images(inputs):
        """Return one list of images per example; string entries are opened with PIL.

        ``example["image"]`` may be missing/None (no image), a single image or a list of images.
        """
        images_per_example = []
        for example in inputs:
            raw = example.get("image")
            if raw is None:
                items = []
            elif isinstance(raw, list):
                items = raw
            else:
                items = [raw]
            images_per_example.append([Image.open(item) if isinstance(item, str) else item for item in items])
        return images_per_example

    def _generate_with_vllm(self, prompts_text, images_per_example, device):
        """Sample ``num_generations`` completions per prompt with the main-process vLLM engine.

        Requests from all processes are gathered, generated on the main process, broadcast,
        and each process keeps its own slice. Returns right-padded completion ids in the
        tiled row layout (row ``r`` belongs to example ``r % len(prompts_text)``).
        """
        # Refresh vLLM's weights at most once per optimizer step (collective under ZeRO-3).
        if self.state.global_step != self._last_loaded_step:
            if self._vllm_weight_sync:
                self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        local_requests = []
        for text, example_images in zip(prompts_text, images_per_example):
            request = {"prompt": text}
            if len(example_images) == 1:
                request["multi_modal_data"] = {"image": example_images[0]}
            elif len(example_images) > 1:
                request["multi_modal_data"] = {"image": example_images}
            local_requests.append(request)
        # Tile (not interleave) to match the tiled prompt tensors.
        local_requests = self.num_generations * local_requests

        all_requests = gather_object(local_requests)
        if self.accelerator.is_main_process:
            outputs = self.llm.generate(all_requests, sampling_params=self.sampling_params, use_tqdm=False)
            completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
        else:
            completion_ids = [None] * len(all_requests)
        completion_ids = broadcast_object_list(completion_ids, from_process=0)
        start = self.accelerator.process_index * len(local_requests)
        completion_ids = completion_ids[start : start + len(local_requests)]

        completion_ids = [torch.tensor(ids, dtype=torch.long, device=device) for ids in completion_ids]
        return pad(completion_ids, padding_value=self._pad_token_id)

    def _compute_rewards(self, inputs, prompts, completions, device):
        """Score every completion with every reward function.

        ``prompts`` and ``completions`` are tiled (row ``r`` belongs to ``inputs[r % len(inputs)]``).
        Returns a float tensor of shape ``(len(prompts), len(self.reward_funcs))``.
        """
        num_generations = self.num_generations
        conversational = is_conversational(inputs[0])
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(zip(self.reward_funcs, self.reward_processing_classes)):
            if isinstance(reward_func, PreTrainedModel):
                if conversational:
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
            else:
                # Forward the remaining dataset columns, tiled exactly like the completions.
                reward_kwargs = {
                    key: [example[key] for _ in range(num_generations) for example in inputs]
                    for key in inputs[0].keys()
                    if key not in ["prompt", "completion"]
                }
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, current_step=self.state.global_step, **reward_kwargs
                )
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
        return rewards_per_func

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Generate completions, score them and return the GRPO loss for one batch.

        Args:
            model: The (possibly wrapped) policy model.
            inputs: Raw dataset rows (the collator is the identity), each with a ``"prompt"``
                and optionally an ``"image"`` entry.
            return_outputs: Must be False.
            num_items_in_batch: Unused.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``return_outputs`` is True.
        """
        if return_outputs:
            raise ValueError("GRPOTrainer does not support returning outputs")

        device = self.accelerator.device
        num_generations = self.num_generations

        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        images_per_example = self._load_images(inputs)
        images = [image for example_images in images_per_example for image in example_images]

        prompt_inputs = self.processing_class(
            text=prompts_text,
            images=images if len(images) > 0 else None,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        # Tile every tensor (ids, masks, pixel values, grid sizes, ...) along dim 0.
        prompt_inputs = {
            key: value.repeat(num_generations, *([1] * (value.dim() - 1))) if isinstance(value, torch.Tensor) else value
            for key, value in prompt_inputs.items()
        }

        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        if self.max_prompt_length is not None:
            # NOTE(review): truncation can cut image placeholder tokens while pixel inputs stay
            # whole; keep max_prompt_length large (or None) for image prompts.
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        if self.use_vllm:
            completion_ids = self._generate_with_vllm(prompts_text, images_per_example, device)
        else:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                generated = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
            # generate() echoes the full (untruncated) prompt, so cut at its real length.
            completion_ids = generated[:, prompt_inputs["input_ids"].size(1) :]

        prompt_length = prompt_ids.size(1)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # Mask everything after the first EOS (inclusive of the EOS itself).
        is_eos = completion_ids == self._eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        has_eos = is_eos.any(dim=1)
        eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        prompt_inputs["input_ids"] = prompt_completion_ids
        prompt_inputs["attention_mask"] = attention_mask

        # Log-prob column prompt_length - 1 scores the first completion token.
        per_token_logps = self._get_per_token_logps(model, **prompt_inputs)
        per_token_logps = per_token_logps[:, prompt_length - 1 :]

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(self.ref_model, **prompt_inputs)
            else:
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(model, **prompt_inputs)
        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

        # k3 estimator: exp(ref - pi) - (ref - pi) - 1 (non-negative per token).
        log_ratio = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(log_ratio) - log_ratio - 1

        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Tile prompts to line up with the completions.
        prompts = [prompt for _ in range(num_generations) for prompt in prompts]

        rewards_per_func = self._compute_rewards(inputs, prompts, completions, device)
        rewards = rewards_per_func.sum(dim=1)

        # In the (G, B) view, column b holds the G completions of example b.
        grouped_rewards = rewards.view(num_generations, -1)
        mean_grouped_rewards = grouped_rewards.mean(dim=0).repeat(num_generations)
        std_grouped_rewards = grouped_rewards.std(dim=0).repeat(num_generations)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # exp(x - x.detach()) equals 1 in value but carries the policy gradient.
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = getattr(reward_func, "__name__", type(reward_func).__name__)
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the GRPO metrics averaged since the last call into ``logs``, log, then reset them."""
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}
        logs = {**logs, **metrics}
        # ``start_time`` is only accepted by Trainer.log from transformers 4.47 on.
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """Write ``README.md`` (model card) into ``args.output_dir``; no-op off world process zero.

        Args:
            model_name: Name of the trained model.
            dataset_name: Name of the training dataset.
            tags: Extra tag(s); the caller's list is not modified.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        tags = [tags] if isinstance(tags, str) else list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))