# Copyright 2025 The HuggingFace Team. All rights reserved.
# (License header remains the same)

import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union, Dict, List 

import torch
import torch.utils.data
import transformers
from accelerate.utils import set_seed
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, selective_log_softmax
from trl.trainer.callbacks import SyncRefModelCallback

import warnings
# Silence any warning whose message matches "Could not estimate the number of tokens";
# the filter is installed at import time and applies process-wide.
warnings.filterwarnings("ignore", message=".*Could not estimate the number of tokens.*")

if is_peft_available():
    from peft import PeftConfig, get_peft_model, PeftModel

if is_wandb_available():
    import wandb

# Type alias for a reward function. `GRPOTrainer.compute_loss` invokes each one with the keyword
# arguments `completions` and `solution` and converts the returned list into a float32 tensor.
RewardFunc = Callable[[list, list], list[float]]

class GRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method, adapted for the custom TWNM model.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        ref_model: Optional[PreTrainedModel] = None,
        args: Optional[GRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        """
        Validate the inputs, optionally wrap the model with PEFT, and configure generation.

        Args:
            model: An already-instantiated model. Its `config._name_or_path` is read to build a
                default output directory when `args` is omitted.
            reward_funcs: A single reward callable or a list of them; a single callable is wrapped in a list.
            ref_model: Optional reference model. When given, it is prepared through the accelerator in
                evaluation mode. When `None`, `compute_loss` obtains reference log-probabilities by running
                the policy model inside its `disable_adapters()` context instead.
            args: Training configuration. Defaults to `GRPOConfig("<model name>-GRPO")`.
            train_dataset: Forwarded to `Trainer`.
            eval_dataset: Forwarded to `Trainer`.
            tokenizer: Required. Stored as `self.tokenizer` and forwarded to `Trainer` as `processing_class`.
            callbacks: Forwarded to `Trainer`.
            optimizers: Forwarded to `Trainer`.
            peft_config: When given, the model is wrapped with `get_peft_model` before training.

        Raises:
            ValueError: If `model` is a string or `tokenizer` is `None`.
            ImportError: If `peft_config` is given but PEFT is not installed.
        """
        # Validate before anything touches `model.config`: the default-config branch below would
        # otherwise fail with an unrelated AttributeError when a model name string is passed.
        if isinstance(model, str):
            raise ValueError("Please instantiate your TWNM model first.")
        if tokenizer is None:
            raise ValueError("A tokenizer must be provided.")

        if args is None:
            model_name = model.config._name_or_path or "twnm"
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        if peft_config is not None:
            # `get_peft_model` is only imported when PEFT is installed; fail with a clear message
            # rather than a NameError.
            if not is_peft_available():
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

        # Every tokenizer call in this trainer pads the batch, which requires a pad token.
        # Fall back to EOS when the tokenizer defines none.
        if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

        self.ref_model = ref_model
        # The methods of this class read `self.tokenizer`; the same object is also handed to
        # `Trainer` below as `processing_class`.
        self.tokenizer = tokenizer

        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_funcs = reward_funcs

        def data_collator(features):
            # No collation: `compute_loss` consumes the raw list of feature dicts.
            return features

        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length
        self.num_generations = args.num_generations
        self.beta = args.beta
        # Per-step metric values, averaged and flushed by `log`.
        self._metrics = defaultdict(list)

        super().__init__(
            model=model, args=args, data_collator=data_collator, train_dataset=train_dataset,
            eval_dataset=eval_dataset, processing_class=tokenizer, callbacks=callbacks, optimizers=optimizers,
        )

        set_seed(args.seed, device_specific=True)

        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length, do_sample=True, temperature=args.temperature,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        if self.ref_model is not None:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
        
    def _set_signature_columns_if_needed(self):
        """Pin the dataset columns this trainer keeps, unless they have already been set."""
        if self._signature_columns is not None:
            return
        # These are assigned explicitly rather than derived from a model signature.
        self._signature_columns = ["audio", "task", "text", "solution"]

    def _get_per_token_logps(self, model: PreTrainedModel, model_inputs: Dict[str, Any], encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-probability the model assigns to each completion token.

        Args:
            model: Invoked as `model(model_inputs, encoder_hidden_states=...)`; its output must
                expose a `"logits"` entry indexed as (batch, sequence, vocab).
            model_inputs: Passed to the model unchanged. Its `"text"` entry (the completion strings)
                is re-tokenized here, without special tokens, to obtain the target ids.
            encoder_hidden_states: Forwarded to the model as a keyword argument.

        Returns:
            A tensor with the same (batch, completion_len) shape as the padded completion ids.
            Entries at padded positions are 0.

        Raises:
            ValueError: If the logits sequence is not longer than the tokenized completions, so no
                prompt position exists to predict the first completion token.
        """
        outputs = model(model_inputs, encoder_hidden_states=encoder_hidden_states)
        logits = outputs["logits"]

        completion_tokens = self.tokenizer(
            model_inputs["text"], add_special_tokens=False, padding=True, return_tensors="pt"
        ).to(logits.device)
        completion_ids = completion_tokens.input_ids
        is_real_token = completion_tokens.attention_mask == 1

        # The code treats the completion as the trailing positions of the logits sequence.
        # NOTE(review): this relies on the model appending the completion tokens last and on the
        # tokenizer padding on the right - confirm against the TWNM forward pass.
        num_prompt_tokens = logits.shape[1] - completion_ids.shape[1]
        if num_prompt_tokens < 1:
            raise ValueError(
                f"Cannot align logits (sequence length {logits.shape[1]}) with completions "
                f"(length {completion_ids.shape[1]}): at least one preceding position is required."
            )
        # The logit at position i scores token i + 1, hence the one-step shift to the left.
        completion_logits = logits[:, num_prompt_tokens - 1:-1, :]

        # Padded slots get a valid placeholder index (0) so the lookup cannot go out of range;
        # their values are zeroed afterwards.
        ids_for_gather = completion_ids.masked_fill(~is_real_token, 0)

        # `selective_log_softmax` returns log_softmax(logits) gathered at the given ids without
        # materialising the full (batch, length, vocab) log-probability tensor.
        per_token_logps = selective_log_softmax(completion_logits, ids_for_gather)

        # Out-of-place masking avoids an in-place write on a tensor that is part of the autograd graph.
        return per_token_logps.masked_fill(~is_real_token, 0.0)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Sample completions for each example and compute the GRPO loss over them.

        Steps: encode the audio once per example, sample `self.num_generations` completions per
        example, score them with the reward functions, normalise rewards within each group into
        advantages, and combine the policy term with a `self.beta`-weighted KL penalty against
        the reference model.

        Args:
            model: The policy model used for generation and for the policy log-probabilities.
            inputs: A list of feature dicts (the collator passes features through unchanged); each
                is read for `"task"`, `"audio"` and `"solution"`. The `"audio"` entries are stacked,
                so they must share one shape.
            return_outputs: Accepted for `Trainer` compatibility but not read; only the loss is returned.
            **kwargs: Accepted for `Trainer` compatibility and not read.

        Returns:
            The scalar loss tensor. Completion length, reward, reward std and KL are appended to
            `self._metrics` as a side effect.
        """
        device = self.accelerator.device
        group_size = self.num_generations

        tasks = [x["task"] for x in inputs]
        audios = torch.stack([x["audio"] for x in inputs]).to(device)
        solutions = [x["solution"] for x in inputs]

        task_tokens = self.tokenizer(
            tasks, return_tensors="pt", padding=True, truncation=True, max_length=self.max_prompt_length
        )
        task_tokens = {k: v.to(device) for k, v in task_tokens.items()}

        # 1. Encode audio features ONCE (no gradient flows through the encoder here)
        prompts_for_encoder = [t + " <AcousticTokens>" for t in tasks]
        with torch.no_grad():
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            encoder_hidden_states, _ = unwrapped_model.forward_encoder(audios, prompts_for_encoder)
            encoder_hidden_states = encoder_hidden_states.to(unwrapped_model.decoder.dtype)

        # 2. Generate completions: every per-example input is repeated `group_size` times so that
        # consecutive rows belong to the same example.
        encoder_hidden_states_repeated = encoder_hidden_states.repeat_interleave(group_size, dim=0)
        input_ids_repeated = task_tokens["input_ids"].repeat_interleave(group_size, dim=0)
        attention_mask_repeated = task_tokens["attention_mask"].repeat_interleave(group_size, dim=0)
        prompts_for_generate = [p for p in prompts_for_encoder for _ in range(group_size)]
        tasks_repeated = [task for task in tasks for _ in range(group_size)]
        solutions_repeated = [sol for sol in solutions for _ in range(group_size)]

        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model_gen:
            generated_ids = unwrapped_model_gen.generate(
                input_ids=input_ids_repeated, attention_mask=attention_mask_repeated,
                encoder_hidden_states=encoder_hidden_states_repeated, prompt=prompts_for_generate,
                generation_config=self.generation_config,
            )

        # Extract completions at the string level instead of slicing token ids by prompt length:
        # decode the whole sequence, then strip the task prefix when the decoded text starts with it.
        # (Token-count slicing can be misaligned with the prefix length the model actually used.)
        full_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        completions_text = [
            ft[len(t):].lstrip() if ft.startswith(t) else ft
            for ft, t in zip(full_texts, tasks_repeated)
        ]
        # Re-tokenize the completion text only to obtain its attention mask, which marks the real
        # (non-padding) completion tokens without depending on the pad token id.
        completion_tokens_for_mask = self.tokenizer(
            completions_text, add_special_tokens=False, padding=True, return_tensors="pt"
        ).to(device)
        completion_mask = completion_tokens_for_mask.attention_mask.int()

        # NOTE(review): unconditional debug output on every step and every process.
        print("Sample generations:")
        print(completions_text)

        # 3. Get log probabilities for policy and reference models
        policy_model_inputs = {"task": tasks_repeated, "text": completions_text}

        per_token_logps = self._get_per_token_logps(
            model, policy_model_inputs, encoder_hidden_states=encoder_hidden_states_repeated
        )

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, policy_model_inputs, encoder_hidden_states=encoder_hidden_states_repeated
                )
            else:  # Single-model, adapter-switching logic
                # NOTE(review): requires the unwrapped model to expose `disable_adapters()` as a
                # context manager - confirm for the model class in use.
                with self.accelerator.unwrap_model(model).disable_adapters():
                    ref_per_token_logps = self._get_per_token_logps(
                        model, policy_model_inputs, encoder_hidden_states=encoder_hidden_states_repeated
                    )

        # 4. Compute KL divergence, rewards, and loss
        # exp(d) - d - 1 with d = ref - policy is always >= 0 and is 0 exactly when the two agree.
        logp_diff = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(logp_diff) - logp_diff - 1

        completions_for_reward = [[{"role": "assistant", "content": text}] for text in completions_text]

        rewards_per_func = [
            torch.tensor(func(completions=completions_for_reward, solution=solutions_repeated), dtype=torch.float32, device=device)
            for func in self.reward_funcs
        ]
        rewards = torch.stack(rewards_per_func).sum(dim=0)

        # Normalise each reward against the other samples drawn for the same example.
        grouped_rewards = rewards.view(-1, group_size)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(group_size, dim=0)
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(group_size, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Defensive alignment in case the log-prob and mask lengths differ.
        min_len = min(per_token_logps.shape[1], completion_mask.shape[1])
        per_token_logps = per_token_logps[:, :min_len]
        completion_mask = completion_mask[:, :min_len]
        per_token_kl = per_token_kl[:, :min_len]

        # A sample whose completion is empty has an all-zero mask; clamping the denominator makes
        # it contribute 0 instead of turning the whole batch loss into NaN through 0 / 0.
        tokens_per_sample = completion_mask.sum(dim=1).clamp(min=1)

        # exp(logp - logp.detach()) equals 1 in value but carries the gradient of logp.
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / tokens_per_sample).mean()

        # Log metrics
        with torch.no_grad():
            self._metrics["completion_length"].append(self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item())
            self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
            self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / tokens_per_sample).mean()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

    # (The log and create_model_card methods remain the same)
    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Merge the averaged trainer metrics into `logs`, forward them to `Trainer.log`, and reset the buffer.

        Args:
            logs: Values supplied by the caller; a buffered metric with the same key overrides it.
            start_time: Forwarded to `Trainer.log` only on transformers >= 4.47.0.dev0.
        """
        # Average every metric collected since the previous call.
        averaged = {}
        for name, values in self._metrics.items():
            averaged[name] = sum(values) / len(values)
        merged_logs = {**logs, **averaged}

        accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if accepts_start_time:
            super().log(merged_logs, start_time)
        else:
            # transformers <= 4.46: `start_time` is not passed.
            super().log(merged_logs)

        self._metrics.clear()

    def create_model_card(self, model_name: Optional[str] = None, **kwargs):
        """Do nothing: this trainer does not generate a model card; all arguments are ignored."""
        return None