# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments, DataArguments

import torch.nn.functional as F

# module-level logger from the project's own logging helpers (`...extras.logging`)
logger = logging.get_logger(__name__)


def neg_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute the negative log-probability (per-token cross-entropy) of each label id.

    The result is ``-log_softmax(logits)[..., label]``, computed in float32 regardless of the
    dtype of ``logits``. Note the sign: the values are non-negative NLLs, NOT log-probs.

    Positions whose label equals ``-100`` (``F.cross_entropy``'s default ``ignore_index``)
    produce 0.

    Args:
        logits (torch.Tensor): scores with the vocabulary on the last dim, e.g.
            ``(batch_size, seqlen, vocab_size)``; any leading shape (including none) is accepted.
        labels (torch.Tensor): integer label ids with shape ``logits.shape[:-1]``.

    Returns:
        torch.Tensor: per-position NLL with shape ``logits.shape[:-1]``.
    """
    batch_shape = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    # flatten to (num_positions, vocab) / (num_positions,) as required by F.cross_entropy
    flat_logits = logits.reshape(-1, vocab_size).float()
    flat_labels = labels.reshape(-1)
    nll = F.cross_entropy(flat_logits, flat_labels, reduction="none")
    # reshape (not view with unpacking) so an empty batch shape yields a 0-d tensor
    return nll.reshape(batch_shape)



def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str = "low_var_kl") -> torch.Tensor:
    """Estimate the KL divergence between a policy and a reference from their log-probabilities.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Both inputs are upcast to float32 before any arithmetic.

    Args:
        log_probs: log-probabilities under the policy.
        ref_log_probs: log-probabilities under the reference.
        kl_penalty: which estimator to return:
            - ``"kl"``: ``log_probs - ref_log_probs``.
            - ``"abs"``: absolute value of that difference.
            - ``"mse"``: half of the squared difference.
            - ``"low_var_kl"``: ``exp(d) - d - 1`` with ``d = ref_log_probs - log_probs``,
              clamped to ``[-10, 10]``.
            - ``"full"``: ``F.kl_div(ref_log_probs, log_probs, log_target=True)`` summed over the
              last dim (meant for full log-distributions along that dim).

    Returns:
        kl_div: element-wise estimate with the input shape; ``"full"`` reduces the last dim.

    Raises:
        NotImplementedError: if ``kl_penalty`` is not one of the names above.
    """
    policy = log_probs.float()
    reference = ref_log_probs.float()

    if kl_penalty in ("kl", "abs", "mse"):
        diff = policy - reference
        if kl_penalty == "abs":
            return diff.abs()
        if kl_penalty == "mse":
            return 0.5 * diff.square()
        return diff

    if kl_penalty == "low_var_kl":
        # J. Schulman, "Approximating KL divergence", 2020 -- http://joschu.net/blog/kl-approx.html
        log_ratio = reference - policy
        estimate = (log_ratio.exp() - log_ratio - 1).contiguous()
        return torch.clamp(estimate, min=-10, max=10)

    if kl_penalty == "full":
        return F.kl_div(reference, policy, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.

    On top of the stock trainer it can add an "internalization" loss: when
    ``finetuning_args.use_internalization_loss`` is set, each batch also carries a second copy of the
    inputs (keys ending with ``data_args.additional_message_field_suffix``), and the per-token
    distributions of the supervised tokens under both copies are aligned with a KL term. If a
    ``teacher_model`` is given, it (instead of the trained model) scores the suffixed copy.
    """

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        gen_kwargs: Optional[dict[str, Any]] = None,
        data_args: Optional["DataArguments"] = None,
        teacher_model: Optional["torch.nn.Module"] = None,
        **kwargs,
    ) -> None:
        r"""Set up the trainer.

        Args:
            finetuning_args: fine-tuning options (BAdam, shuffling, internalization-loss settings).
            processor: optional processor; when given it is saved via ``SaveProcessorCallback``.
            gen_kwargs: generation kwargs stored as ``self._gen_kwargs`` for ``Seq2SeqTrainer``.
            data_args: data options; this class only reads ``additional_message_field_suffix``.
            teacher_model: optional model that scores the suffixed inputs for the internalization loss.
            **kwargs: forwarded to ``Seq2SeqTrainer``; expected to contain ``tokenizer``.
        """
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        self.data_args = data_args
        if gen_kwargs is not None:
            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
            self._gen_kwargs = gen_kwargs

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        # teacher model for KD (None -> the trained model scores both views)
        self.teacher_model = teacher_model

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    def _maybe_log_additional_info(self, additional_info: dict[str, float]) -> None:
        r"""Forward extra scalar metrics to ``self.log``.

        NOTE(review): this is called from every ``compute_loss`` invocation (every micro-batch, and
        during evaluation too), not only at ``logging_steps`` boundaries.
        """
        self.log(additional_info)

    def _drop_suffix_fields(self, inputs: dict[str, Any]) -> None:
        r"""Remove, in place, every field whose key ends with the configured additional-message suffix.

        No-op when ``data_args`` is missing or no suffix is configured.
        """
        if self.data_args is None or self.data_args.additional_message_field_suffix is None:
            return

        suffix = self.data_args.additional_message_field_suffix
        for key in [k for k in inputs.keys() if k.endswith(suffix)]:
            inputs.pop(key)

    @staticmethod
    def _select_response_tokens(
        logits: "torch.Tensor", labels: "torch.Tensor"
    ) -> tuple["torch.Tensor", "torch.Tensor"]:
        r"""Return ``(logits, labels)`` at supervised positions, with the causal-LM shift applied.

        Assumes a decoder-only model, where ``logits[:, t]`` predicts ``labels[:, t + 1]``. So the last
        logit and the first label are dropped before masking out ``IGNORE_INDEX``. Selected rows are
        returned in row-major (batch, position) order as ``(N, V)`` logits and ``(N,)`` labels.
        """
        shifted_logits = logits[:, :-1, :]
        shifted_labels = labels[:, 1:]
        mask = shifted_labels != IGNORE_INDEX
        return shifted_logits[mask], shifted_labels[mask]

    def _compute_alignment_kl_loss(
        self,
        logits_wo_policy: "torch.Tensor",  # (B, L1, V)
        logits_w_policy: "torch.Tensor",  # (B, L2, V)
        labels_wo_policy: "torch.Tensor",  # (B, L1)
        labels_w_policy: "torch.Tensor",  # (B, L2)
    ) -> "torch.Tensor":
        r"""Mean per-token KL between the two views of the supervised (response) tokens.

        The i-th supervised token of the plain batch is paired with the i-th supervised token of the
        suffixed batch. Both batches are therefore expected to contain the same response tokens in the
        same order; only the counts are checked.

        NOTE: ``compute_kl`` is fed *negative* log-probs here. With its default ``"low_var_kl"``
        estimator only the difference of the two arguments matters, so negating both swaps their
        roles:
        - ``"forward_kl"`` equals ``compute_kl(log_probs=log p_w, ref_log_probs=log p_wo)``.
        - ``"reverse_kl"`` equals ``compute_kl(log_probs=log p_wo, ref_log_probs=log p_w)``.

        Raises:
            ValueError: on an unknown ``internalization_loss_type`` or mismatched supervised-token counts.
        """
        loss_type = self.finetuning_args.internalization_loss_type
        if loss_type not in ("forward_kl", "reverse_kl"):
            raise ValueError(f"Unknown internalization loss type: {loss_type}")

        flat_logits_wo, flat_labels_wo = self._select_response_tokens(logits_wo_policy, labels_wo_policy)
        flat_logits_w, flat_labels_w = self._select_response_tokens(logits_w_policy, labels_w_policy)

        num_tokens = flat_labels_wo.numel()
        if num_tokens != flat_labels_w.numel():
            raise ValueError(
                f"Mismatch in valid-token counts: wo={num_tokens} vs w={flat_labels_w.numel()}"
            )

        if num_tokens == 0:
            return flat_logits_wo.new_tensor(0.0)

        neglog_wo = neg_log_probs_from_logits(flat_logits_wo, flat_labels_wo)  # (N,)
        neglog_w = neg_log_probs_from_logits(flat_logits_w, flat_labels_w)  # (N,)

        if loss_type == "forward_kl":
            kl_vals = compute_kl(log_probs=neglog_wo, ref_log_probs=neglog_w)  # (N,)
        else:
            kl_vals = compute_kl(log_probs=neglog_w, ref_log_probs=neglog_wo)  # (N,)
        return kl_vals.mean()

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        r"""Compute the loss, optionally adding the internalization (KL alignment) term.

        Without ``use_internalization_loss``:
        - The suffixed fields are dropped from ``inputs`` in place.
        - The parent loss is returned unchanged.

        With it:
        - The batch is split into the plain inputs and the suffixed copy (suffix stripped from its keys).
        - The loss is ``ntp_loss * ntp_loss_weight + kl * internalization_loss_weight``, where
          ``ntp_loss`` is the parent loss on the plain inputs.
        - When called with ``return_outputs=True`` (as ``Trainer.prediction_step`` does), it returns
          ``(loss, outputs_of_plain_inputs)``.

        Raises:
            ValueError: if the internalization loss is enabled but no field suffix is configured.
        """
        if not self.finetuning_args.use_internalization_loss:
            self._drop_suffix_fields(inputs)
            return super().compute_loss(model, inputs, *args, **kwargs)

        suffix = self.data_args.additional_message_field_suffix if self.data_args is not None else None
        if suffix is None:
            raise ValueError(
                "`use_internalization_loss` requires `data_args.additional_message_field_suffix` to be set."
            )

        # Parent signature: compute_loss(model, inputs, return_outputs=False, ...). The calls below always
        # need the outputs, so take the caller's flag out of args/kwargs to avoid passing it twice.
        extra_args = list(args)
        return_outputs = extra_args.pop(0) if extra_args else kwargs.pop("return_outputs", False)

        inputs_wo_policy = {k: v for k, v in inputs.items() if not k.endswith(suffix)}
        inputs_w_policy = {k.replace(f"__{suffix}", ""): v for k, v in inputs.items() if k.endswith(suffix)}

        ntp_loss, outputs_wo_policy = super().compute_loss(model, inputs_wo_policy, True, *extra_args, **kwargs)

        if self.teacher_model is not None:
            # knowledge distillation: the teacher scores the suffixed inputs without building a graph
            with torch.no_grad():
                _, outputs_w_policy = super().compute_loss(
                    self.teacher_model, inputs_w_policy, True, *extra_args, **kwargs
                )
        else:
            # the trained model scores both views; gradients flow through both forward passes
            _, outputs_w_policy = super().compute_loss(model, inputs_w_policy, True, *extra_args, **kwargs)

        alignment_kl_loss = self._compute_alignment_kl_loss(
            logits_wo_policy=outputs_wo_policy.logits,
            logits_w_policy=outputs_w_policy.logits,
            labels_wo_policy=inputs_wo_policy["labels"],
            labels_w_policy=inputs_w_policy["labels"],
        )
        total_loss = (
            ntp_loss * self.finetuning_args.ntp_loss_weight
            + alignment_kl_loss * self.finetuning_args.internalization_loss_weight
        )
        self._maybe_log_additional_info(
            {
                "total_loss": total_loss.detach().item(),
                "ntp_loss": ntp_loss.detach().item(),
                "internalization_loss": alignment_kl_loss.detach().item(),
                "ntp_loss_weight": self.finetuning_args.ntp_loss_weight,
                "internalization_loss_weight": self.finetuning_args.internalization_loss_weight,
            }
        )
        return (total_loss, outputs_wo_policy) if return_outputs else total_loss

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part in the generated tokens.

        In generate mode the labels and the suffixed additional-message fields are removed from
        ``inputs`` (in place) before generation. The prompt positions of the generated sequences are
        overwritten with the pad token id.

        Subclass and override to inject custom behavior.
        """
        if self.args.predict_with_generate:  # do not pass labels to model when generate
            labels = inputs.pop("labels", None)
            # suffixed copies only feed the internalization loss; presumably `generate` would reject
            # them as unused model kwargs
            self._drop_suffix_fields(inputs)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that not contained in Seq2SeqTrainer. Only the main process writes
        ``generated_predictions.jsonl`` (one JSON object with prompt/predict/label per line).
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            self.processing_class.pad_token_id,
        )

        for i in range(len(preds)):
            non_pad_positions = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
            if len(non_pad_positions):  # move leading pad tokens to the end
                first = non_pad_positions[0]
                preds[i] = np.concatenate((preds[i][first:], preds[i][:first]), axis=-1)

        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
