# Copyright 2025 The HuggingFace Team. All rights reserved.
# (License header remains the same)

import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union, Dict, List

import torch
import torch.utils.data
import transformers
from accelerate.utils import set_seed
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.models import (
    create_reference_model,
    prepare_deepspeed,
    unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import (
    generate_model_card,
    get_comet_experiment_url,
    selective_log_softmax,
)
from trl.trainer.callbacks import SyncRefModelCallback

import warnings

# Suppress one specific, noisy warning by message pattern. It is presumably
# emitted by transformers.Trainer when it cannot count input tokens for the
# list-of-dict batches used here — verify the source if the message changes.
warnings.filterwarnings(
    action="ignore",
    message=".*Could not estimate the number of tokens.*",
)

if is_peft_available():
    from peft import PeftConfig, get_peft_model, PeftModel

if is_wandb_available():
    import wandb

# Signature of a reward function. GRPOTrainer.compute_loss calls each one with
# keyword arguments (``completions=...``, ``solution=...``) and expects one float
# per completion back, so the parameter list is left open (``...``) instead of
# being pinned to two positional ``list`` arguments.
RewardFunc = Callable[..., list[float]]


class GRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method, adapted for the custom TWNM model.

    Each training step encodes the audio of the batch once, samples
    ``num_generations`` completions per prompt, scores them with the reward
    functions, normalizes the rewards within each prompt group into advantages,
    and minimizes a per-token policy-gradient loss with a KL penalty towards a
    reference policy. The reference policy is either a separate ``ref_model`` or,
    when none is given, the trained model itself after ``change_to_default()``.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        ref_model: Optional[PreTrainedModel] = None,
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
        processor: Optional[PreTrainedTokenizerBase] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        """
        Set up the GRPO trainer.

        Args:
            model: An already-instantiated TWNM model (a model-name string is rejected).
            reward_funcs: One reward function or a list of them. Each is called as
                ``func(completions=..., solution=...)`` and must return one number
                per completion; the per-function rewards are summed.
            ref_model: Optional reference model. When ``None``, reference log-probs
                are computed with ``model`` after ``change_to_default()``.
            args: Training configuration; defaults to ``GRPOConfig("<model-name>-GRPO")``.
            train_dataset: Training rows; ``compute_loss`` reads "task", "audio"
                and "solution" from each row.
            eval_dataset: Forwarded to ``transformers.Trainer``.
            processor: Tokenizer used for prompts, decoding and re-tokenization. Required.
            callbacks: Forwarded to ``transformers.Trainer``.
            optimizers: Forwarded to ``transformers.Trainer``.
            peft_config: If given, ``model`` is wrapped with ``get_peft_model``.

        Raises:
            ValueError: If ``model`` is a string or ``processor`` is ``None``.
            ImportError: If ``peft_config`` is given but ``peft`` is not installed.
        """
        # Check the model type first: the default-args branch below reads model.config.
        if isinstance(model, str):
            raise ValueError("Please instantiate your TWNM model first.")
        if args is None:
            model_name = model.config._name_or_path or "twnm"
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        if peft_config is not None:
            if not is_peft_available():
                raise ImportError(
                    "`peft_config` was provided but the `peft` package is not installed."
                )
            model = get_peft_model(model, peft_config)

        self.ref_model = ref_model
        if processor is None:
            raise ValueError("A tokenizer must be provided.")
        # The same object is passed to Trainer as `processing_class` below; the
        # methods of this class use `self.tokenizer` for tokenizing and decoding.
        self.tokenizer = processor

        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_funcs = reward_funcs

        def data_collator(features):
            # Keep the raw rows: compute_loss reads "task", "audio" and "solution" itself.
            return features

        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations
        # NOTE(review): the KL coefficient is hard-coded; args.beta is not consulted.
        # Confirm this is intended before relying on the config value.
        self.beta = 0.2
        # Per-step metric values; averaged and cleared in log().
        self._metrics = defaultdict(list)

        print(f"--- GRPO Trainer arguments ---\n{args}\n")

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processor,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        set_seed(args.seed, device_specific=True)

        # NOTE(review): this config is stored but not passed to generate() in
        # compute_loss, which uses its own hard-coded sampling settings.
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=args.temperature,
            pad_token_id=processor.pad_token_id,
            eos_token_id=processor.eos_token_id,
        )

        if self.ref_model is not None:
            # Prepare the reference model for evaluation only (evaluation_mode=True).
            self.ref_model = self.accelerator.prepare_model(
                self.ref_model, evaluation_mode=True
            )

        if args.sync_ref_model:
            self.add_callback(
                SyncRefModelCallback(
                    ref_model=self.ref_model, accelerator=self.accelerator
                )
            )

    def _set_signature_columns_if_needed(self):
        """Declare the dataset columns to keep, so Trainer does not drop them."""
        if self._signature_columns is None:
            self._signature_columns = ["audio", "task", "text", "solution"]

    @staticmethod
    def _adapter_host(model):
        """
        Return the object exposing ``change_to_policy()`` / ``change_to_default()``.

        A distributed wrapper such as DistributedDataParallel exposes the user
        model as ``.module``; an unwrapped model has no such attribute and is
        returned unchanged.
        """
        return getattr(model, "module", model)

    @staticmethod
    def _is_peft_instance(obj) -> bool:
        """Return True only if ``peft`` is installed and ``obj`` is a ``PeftModel``."""
        return is_peft_available() and isinstance(obj, PeftModel)

    def _get_per_token_logps(
        self,
        model: PreTrainedModel,
        model_inputs: Dict[str, Any],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the log-probability of every completion token under ``model``.

        The completion strings in ``model_inputs["text"]`` are re-tokenized
        (no special tokens added, padded to the longest). The last ``L + 1``
        logit positions, excluding the final one, are scored against those
        ``L`` tokens, because the logit at position t predicts token t + 1.

        Args:
            model: Model called as ``model(model_inputs, encoder_hidden_states=...)``;
                its output must provide ``"logits"`` indexed as (batch, position, vocab).
            model_inputs: Dict with at least a "text" list of completion strings.
            encoder_hidden_states: Forwarded to ``model`` unchanged.

        Returns:
            Tensor of shape (batch, L) with per-token log-probs; padding positions are 0.

        Raises:
            ValueError: If the logits are not longer than the completion tokens,
                leaving no position to predict the first completion token from.
        """
        outputs = model(model_inputs, encoder_hidden_states=encoder_hidden_states)
        logits = outputs["logits"]

        completion_tokens = self.tokenizer(
            model_inputs["text"],
            add_special_tokens=False,
            padding=True,
            return_tensors="pt",
        ).to(logits.device)
        completion_ids = completion_tokens.input_ids
        is_real_token = completion_tokens.attention_mask == 1

        num_prompt_tokens = logits.shape[1] - completion_ids.shape[1]
        if num_prompt_tokens < 1:
            raise ValueError(
                f"Model produced {logits.shape[1]} logit positions for "
                f"{completion_ids.shape[1]} completion tokens; expected more logit positions."
            )
        # Shift by one: the logit at position t scores the token at position t + 1.
        completion_logits = logits[:, num_prompt_tokens - 1 : -1, :]

        # Padding positions get placeholder index 0 so gather never sees an
        # out-of-range id; their values are zeroed after the gather.
        ids_for_gather = completion_ids.masked_fill(~is_real_token, 0)
        log_probs = torch.nn.functional.log_softmax(completion_logits, dim=-1)
        per_token_logps = torch.gather(
            log_probs, dim=-1, index=ids_for_gather.unsqueeze(-1)
        ).squeeze(-1)

        return per_token_logps.masked_fill(~is_real_token, 0.0)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Run one GRPO step on a batch of raw rows and return the scalar loss.

        Args:
            model: The policy model as passed in by Trainer (possibly wrapped).
            inputs: List of rows with "task", "audio" and "solution" entries
                (the collator returns rows unchanged).
            return_outputs: Not supported; must be False.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``return_outputs`` is True.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        tasks = [x["task"] for x in inputs]
        audios = torch.stack([x["audio"] for x in inputs]).to(self.accelerator.device)
        solutions = [x["solution"] for x in inputs]

        task_tokens = self.tokenizer(
            tasks,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_prompt_length,
        )
        task_tokens = {k: v.to(self.accelerator.device) for k, v in task_tokens.items()}

        # 1. Encode the audio once; the result is repeated for every generation below.
        prompts_for_encoder = [t + " <AcousticTokens>" for t in tasks]
        with torch.no_grad():
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            encoder_hidden_states, _ = unwrapped_model.forward_encoder(
                audios, prompts_for_encoder
            )

        # 2. Sample num_generations completions per prompt.
        encoder_hidden_states_repeated = encoder_hidden_states.repeat_interleave(
            self.num_generations, dim=0
        )
        input_ids_repeated = task_tokens["input_ids"].repeat_interleave(
            self.num_generations, dim=0
        )

        adapter_host = self._adapter_host(model)
        adapter_host.change_to_policy()

        with unwrap_model_for_generation(
            model, self.accelerator
        ) as unwrapped_model_gen:
            unwrapped_model_gen.eval()
            try:
                # Sample without building an autograd graph.
                with torch.no_grad():
                    generated_ids = unwrapped_model_gen.generate(
                        input_ids=input_ids_repeated,
                        encoder_hidden_states=encoder_hidden_states_repeated,
                        # NOTE(review): hard-coded; self.generation_config
                        # (max_completion_length / temperature) is not used here.
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                    )
            finally:
                # Restore training mode whether or not generation succeeded.
                unwrapped_model_gen.train()

        # Special tokens are kept in the decoded text (False is the decode default).
        completions_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        if self.accelerator.is_main_process:
            print("Sample generations:")
            print(completions_text[0])

        # 3. Per-token log-probs under the policy and the reference policy.
        tasks_repeated = [task for task in tasks for _ in range(self.num_generations)]
        policy_model_inputs = {"task": tasks_repeated, "text": completions_text}

        per_token_logps = self._get_per_token_logps(
            model,
            policy_model_inputs,
            encoder_hidden_states=encoder_hidden_states_repeated,
        )

        # no_grad (not inference_mode): the reference log-probs enter the
        # autograd-tracked loss expression below.
        with torch.no_grad():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model,
                    policy_model_inputs,
                    encoder_hidden_states=encoder_hidden_states_repeated,
                )
            else:
                # Single-model mode: the default adapter acts as the reference.
                adapter_host.change_to_default()
                try:
                    ref_per_token_logps = self._get_per_token_logps(
                        model,
                        policy_model_inputs,
                        encoder_hidden_states=encoder_hidden_states_repeated,
                    )
                finally:
                    # Always switch back, so a failure cannot leave training on the
                    # reference adapter.
                    adapter_host.change_to_policy()

        # 4. KL penalty (exp(d) - d - 1 with d = ref - policy), rewards, advantages.
        log_ratio = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(log_ratio) - log_ratio - 1

        completions_for_reward = [
            [{"role": "assistant", "content": text}] for text in completions_text
        ]
        solutions_repeated = [
            sol for sol in solutions for _ in range(self.num_generations)
        ]

        rewards_per_func = [
            torch.tensor(
                func(completions=completions_for_reward, solution=solutions_repeated),
                dtype=torch.float32,
                device=self.accelerator.device,
            )
            for func in self.reward_funcs
        ]
        rewards = torch.stack(rewards_per_func).sum(dim=0)

        # Normalize each reward within its group of num_generations completions.
        grouped_rewards = rewards.view(-1, self.num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(
            self.num_generations, dim=0
        )
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(
            self.num_generations, dim=0
        )
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # 5. Masked per-token loss.
        # NOTE(review): the mask comes from generated_ids, while per_token_logps is
        # aligned to the re-tokenized completion text; both are only truncated to a
        # common length, not aligned token by token.
        completion_mask = (generated_ids != self.tokenizer.pad_token_id).int()
        min_len = min(per_token_logps.shape[1], completion_mask.shape[1])
        per_token_logps = per_token_logps[:, :min_len]
        completion_mask = completion_mask[:, :min_len]
        per_token_kl = per_token_kl[:, :min_len]

        # exp(x - x.detach()) equals 1 in value but carries the gradient of x, so
        # multiplying by the advantage yields the policy-gradient term.
        per_token_loss = torch.exp(
            per_token_logps - per_token_logps.detach()
        ) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)

        # Clamp so a row without any unmasked token contributes 0 instead of NaN.
        token_counts = completion_mask.sum(dim=1).clamp(min=1)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / token_counts).mean()

        # 6. Metrics, averaged over processes.
        with torch.no_grad():
            self._metrics["completion_length"].append(
                self.accelerator.gather_for_metrics(completion_mask.sum(1))
                .float()
                .mean()
                .item()
            )
            self._metrics["reward"].append(
                self.accelerator.gather_for_metrics(rewards).mean().item()
            )
            self._metrics["reward_std"].append(
                self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()
            )
            mean_kl = (
                (per_token_kl * completion_mask).sum(dim=1) / token_counts
            ).mean()
            self._metrics["kl"].append(
                self.accelerator.gather_for_metrics(mean_kl).mean().item()
            )

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the averaged GRPO metrics into ``logs``, forward to Trainer.log, then reset them."""
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}
        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            # Newer Trainer.log accepts start_time; pass it through.
            super().log(logs, start_time)
        else:  # transformers<=4.46
            # Older Trainer.log takes no start_time argument.
            super().log(logs)
        self._metrics.clear()

    def create_model_card(self, model_name: Optional[str] = None, **kwargs):
        """No-op: this trainer does not generate a model card."""
        pass

    def save_model(
        self, output_dir: Optional[str] = None, _internal_call: bool = False
    ):
        """
        Save the LoRA adapter weights (or fall back to Trainer) plus the tokenizer.

        Order of preference:
          1. ``self.model.decoder`` is a ``PeftModel``: save only the "policy"
             adapter into ``<output_dir>/policy``.
          2. ``self.model`` itself is a ``PeftModel``: save it into ``output_dir``.
          3. Otherwise, delegate to ``Trainer.save_model`` (full model).

        Files written directly by this method are written only on the process
        where ``args.should_save`` is True, so ranks do not race on the same
        files. ``Trainer.save_model`` applies its own process gating.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        decoder = getattr(self.model, "decoder", None)
        if self._is_peft_instance(decoder):
            if self.args.should_save:
                adapter_dir = os.path.join(output_dir, "policy")
                os.makedirs(adapter_dir, exist_ok=True)
                # Writes only the "policy" adapter weights, not the base model.
                decoder.save_pretrained(adapter_dir, selected_adapters=["policy"])
        elif self._is_peft_instance(self.model):
            # Fallback: the top-level model is itself a PeftModel.
            if self.args.should_save:
                self.model.save_pretrained(output_dir)
        else:
            # No LoRA found: save the full model through the parent class.
            super().save_model(output_dir, _internal_call=_internal_call)

        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
