# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import Seq2SeqTrainer
from torch.utils.data import DataLoader
from typing_extensions import override
from .scorer import Llama_Scorer
from .custom_batchsampler import CustomBatchSampler

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import DataArguments, FinetuningArguments


logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.

    On top of the standard behaviour this trainer:

    * builds the training dataloader from ``train_dataset.pruning_sampler()`` wrapped in a
      ``CustomBatchSampler`` (see ``get_train_dataloader``);
    * when ``data_args.plug == "wisely"``, trains each step on a subset of the batch chosen by
      per-sample perplexity/entropy quadrant, and may additionally drop individual supervised
      tokens of the selected Q2 (high-ppl / low-entropy) samples (see ``compute_loss``).
    """

    def __init__(
        self,
        reference_model,
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        r"""Initialize the trainer.

        Args:
            reference_model: optional model (or ``None``); when given it is cast to the trained
                model's dtype/device, resized to the tokenizer length and stored as
                ``self.reference_model`` (``None`` otherwise).
            data_args: data arguments. ``token_method`` and ``token_ratio`` are read directly;
                ``data_ratio`` (default 1.0), ``plug`` (default None) and ``wise_lambda``
                (default 0.5) are optional.
            finetuning_args: fine-tuning arguments (``use_badam`` is honoured here).
            processor: optional processor, saved through ``SaveProcessorCallback``.
            gen_kwargs: optional generation kwargs stored as ``self._gen_kwargs``.
            **kwargs: forwarded to ``Seq2SeqTrainer``; expected to contain ``tokenizer``.
        """
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if gen_kwargs is not None:
            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
            self._gen_kwargs = gen_kwargs

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        # Always define the attribute so later code can test `self.reference_model is None`.
        self.reference_model = None
        if reference_model is not None:
            # Align the reference model with the trained model's dtype/device and vocabulary size.
            self.reference_model = reference_model.to(
                dtype=next(self.model.parameters()).dtype, device=self.model.device
            )
            self.reference_model.resize_token_embeddings(len(self.processing_class))
            self.reference_model.config.vocab_size = len(self.processing_class)

        # Pruning hyper-parameters.
        self.token_method = data_args.token_method
        self.token_ratio = data_args.token_ratio
        self.data_ratio = getattr(data_args, "data_ratio", 1.0)
        self.plug = getattr(data_args, "plug", None)
        self.wise_lambda = getattr(data_args, "wise_lambda", 0.5)
        # State written by `hook_fn` when it is registered as a forward hook.
        self.last_attn = None
        self._attention_hook_handle = None

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        r"""Create the project-specific optimizer (if any) before deferring to the parent."""
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""Let the project hook set up a custom scheduler, then defer to the parent."""
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def get_train_dataloader(self):
        r"""Build the training dataloader from the dataset's pruning sampler.

        Batches of ``per_device_train_batch_size`` indices come from
        ``CustomBatchSampler(self.train_dataset.pruning_sampler(), batch_size, False)`` and are
        collated with ``self.data_collator``.

        NOTE(review): the loader is returned without ``self.accelerator.prepare`` and with the
        DataLoader defaults (``num_workers=0``, no pinned memory), so ``dataloader_num_workers`` /
        ``dataloader_pin_memory`` are ignored and, under multi-process training, every rank
        presumably iterates the full sampler — confirm this is intended. Worker processes were
        not enabled here because the pruning dataset may be mutated during training.
        """
        return DataLoader(
            dataset=self.train_dataset,
            batch_sampler=CustomBatchSampler(
                self.train_dataset.pruning_sampler(), self.args.per_device_train_batch_size, False
            ),
            collate_fn=self.data_collator,
        )

    def hook_fn(self, module, input, output):
        r"""Forward hook that stores an attention module's output in ``self.last_attn``.

        When the module returns a tuple with more than one element the second element is kept
        (presumably the attention weights — depends on the backbone); otherwise the whole output.
        """
        if isinstance(output, tuple) and len(output) > 1:
            self.last_attn = output[1]
        else:
            self.last_attn = output

    def _resolve_last_attention_module(self, model: "torch.nn.Module"):
        r"""Return the attention sub-module of the last decoder layer, or ``None``.

        The layer is located by ``_resolve_last_layer_module``; the attention module is looked up
        as ``self_attn``, then ``attention``, then ``attn``. Lookup is best-effort: any failure
        yields ``None``.
        """
        last_layer = self._resolve_last_layer_module(model)
        if last_layer is None:
            return None
        try:
            return (
                getattr(last_layer, "self_attn", None)
                or getattr(last_layer, "attention", None)
                or getattr(last_layer, "attn", None)
            )
        except Exception:
            return None

    def _resolve_last_layer_module(self, model: "torch.nn.Module"):
        r"""Return the last decoder layer module, or ``None`` if it cannot be found.

        Unwraps ``model.module`` (DDP-style wrappers) and ``get_base_model()`` (PEFT), then looks
        for the backbone under ``model`` / ``transformer`` / ``base_model`` and its layer list
        under ``layers`` (LLaMA/Qwen style) or ``h`` (GPT style). Useful for registering hooks that
        capture the last layer's hidden states. Best-effort: any failure yields ``None``.
        """
        try:
            core_model = model.module if hasattr(model, "module") else model
            base = getattr(core_model, "get_base_model", lambda: core_model)()
            backbone = (
                getattr(base, "model", None)
                or getattr(base, "transformer", None)
                or getattr(base, "base_model", None)
            )
            if backbone is None:
                return None
            layers = getattr(backbone, "layers", None) or getattr(backbone, "h", None)
            if layers is None or len(layers) == 0:
                return None
            return layers[-1]
        except Exception:
            return None

    @torch.no_grad()
    def _compute_ppl_entropy(
        self, logits: "torch.Tensor", shifted_labels: "torch.Tensor", attn_mask_shift: "torch.Tensor"
    ) -> tuple["np.ndarray", "np.ndarray"]:
        r"""Compute per-sample perplexity and mean token entropy.

        Args:
            logits: ``[B, L-1, V]`` next-token logits.
            shifted_labels: ``[B, L-1]`` target ids; ``-100`` marks unsupervised positions.
            attn_mask_shift: ``[B, L-1]`` attention mask (``> 0`` for valid positions).

        Returns:
            ``(ppl, entropy)`` as float32 numpy arrays of length ``B``:
            ``ppl = exp(mean NLL over supervised & attended positions)`` and
            ``entropy = mean over attended positions of H(p) = -sum p log p``.
            Rows without any valid position get ``ppl = 1`` and ``entropy = 0``.
        """
        B, T, _ = logits.shape
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = torch.softmax(logits, dim=-1)

        # Token-level NLL on supervised positions only.
        nll = torch.zeros((B, T), device=logits.device, dtype=logits.dtype)
        valid_label_mask = (shifted_labels != -100) & (attn_mask_shift > 0)
        if valid_label_mask.any():
            nll[valid_label_mask] = -log_probs[valid_label_mask, shifted_labels[valid_label_mask]]
        token_counts = valid_label_mask.sum(dim=1).clamp_min(1)
        # Reduce in float32: summing many bf16/fp16 values loses precision.
        loss_per_sample = nll.to(dtype=torch.float32).sum(dim=1) / token_counts
        ppl = torch.exp(loss_per_sample).cpu().numpy()

        # Token entropy averaged over attended positions. Cast to float32 before `.numpy()`:
        # numpy has no bfloat16, so converting a bf16 tensor would raise TypeError.
        token_entropy = (-probs * log_probs).sum(dim=-1).to(dtype=torch.float32)  # [B, T]
        valid_attn_mask = attn_mask_shift > 0
        ent_counts = valid_attn_mask.sum(dim=1).clamp_min(1)
        entropy_per_sample = (token_entropy * valid_attn_mask).sum(dim=1) / ent_counts
        entropy = entropy_per_sample.cpu().numpy()
        return ppl, entropy

    def _dynamic_threshold_search_bisect(self, ppl, entropy, keep_ratio, tol=0.01, max_iter=10):
        r"""Bisect a quantile level so that |Q2| + |Q4| is about ``keep_ratio`` of the samples.

        α (PPL axis) and β (entropy axis) start from the same bracket ``[0, 0.49]`` and are
        updated together, so they always take the same value:

        - Q2: ``ppl > Q_{1-α}(ppl)`` and ``entropy <= Q_β(entropy)``
        - Q4: ``ppl <= Q_α(ppl)`` and ``entropy > Q_{1-β}(entropy)``

        Stops early once ``|keep_frac - keep_ratio| < tol``; otherwise returns the last iterate.

        Returns:
            ``(ppl_mid, ent_mid, Q2_mask, Q4_mask)``; ``ppl_mid``/``ent_mid`` are ``None`` and the
            masks empty when there are no samples.
        """
        N = len(ppl)
        if N == 0:
            return None, None, np.zeros((0,), dtype=bool), np.zeros((0,), dtype=bool)

        def _quantile(x, q):
            q = float(min(max(q, 0.0), 1.0))
            return float(np.quantile(x, q))

        alpha_low, alpha_high = 0.0, 0.49
        beta_low, beta_high = 0.0, 0.49
        best = None

        # At least one iteration so `best` is always populated.
        for _ in range(max(1, max_iter)):
            alpha = (alpha_low + alpha_high) / 2.0
            beta = (beta_low + beta_high) / 2.0

            ppl_hi = _quantile(ppl, 1.0 - alpha)
            ppl_lo = _quantile(ppl, alpha)
            ent_lo = _quantile(entropy, beta)
            ent_hi = _quantile(entropy, 1.0 - beta)

            Q2 = (ppl > ppl_hi) & (entropy <= ent_lo)
            Q4 = (ppl <= ppl_lo) & (entropy > ent_hi)
            keep_frac = (Q2.sum() + Q4.sum()) / max(1, N)

            best = (alpha, beta, ppl_hi, ppl_lo, ent_lo, ent_hi, Q2, Q4)
            if abs(keep_frac - keep_ratio) < tol:
                break
            # keep_frac grows with α/β, so move the bracket towards the target.
            if keep_frac < keep_ratio:
                alpha_low, beta_low = alpha, beta
            else:
                alpha_high, beta_high = alpha, beta

        _, _, ppl_hi, ppl_lo, ent_lo, ent_hi, Q2, Q4 = best
        ppl_mid = (ppl_lo + ppl_hi) / 2.0
        ent_mid = (ent_lo + ent_hi) / 2.0
        return ppl_mid, ent_mid, Q2, Q4

    @staticmethod
    def _minmax_normalize(x: "np.ndarray") -> "np.ndarray":
        r"""Min-max scale ``x`` to [0, 1]; a (near-)constant array maps to zeros."""
        x_min = float(np.min(x))
        x_max = float(np.max(x))
        if x_max - x_min < 1e-12:
            return np.zeros_like(x)
        return (x - x_min) / (x_max - x_min)

    def _select_wise_samples(
        self, ppl_np: "np.ndarray", ent_np: "np.ndarray", batch_size: int
    ) -> tuple["np.ndarray", "np.ndarray"]:
        r"""Choose which batch rows to train on and which of them are Q2-like.

        Rows in Q2 ∪ Q4 (see ``_dynamic_threshold_search_bisect``) are kept. If fewer than
        ``K = clamp(round(batch_size * data_ratio), 1, batch_size)`` rows result, the remaining
        rows with the highest ``max(s2, s4)`` score are added, where on min-max normalized values
        ``s2 = ppl - entropy`` (Q2-like) and ``s4 = entropy - ppl`` (Q4-like).

        Returns:
            ``(keep_idx, q2_idx)``: kept row indices (Q2, then Q4, then augmented rows) and the
            token-pruning candidates (original Q2 rows plus augmented rows with ``s2 >= s4``).
        """
        data_ratio = float(self.data_ratio)
        _, _, q2_mask, q4_mask = self._dynamic_threshold_search_bisect(ppl_np, ent_np, data_ratio)
        q2_idx = np.where(q2_mask)[0]
        keep_idx = np.concatenate([q2_idx, np.where(q4_mask)[0]], axis=0)

        k_keep = min(max(1, int(round(batch_size * data_ratio))), batch_size)
        if keep_idx.size >= k_keep or ppl_np.size == 0:
            return keep_idx, q2_idx

        ppl_n = self._minmax_normalize(ppl_np)
        ent_n = self._minmax_normalize(ent_np)
        s2 = ppl_n - ent_n  # prefers Q2 (high ppl, low entropy)
        s4 = ent_n - ppl_n  # prefers Q4 (high entropy, low ppl)
        wise_score = np.maximum(s2, s4)

        # Rows not yet kept, in ascending order.
        remain = np.setdiff1d(np.arange(len(ppl_np)), keep_idx)
        if remain.size == 0:
            return keep_idx, q2_idx

        need = k_keep - keep_idx.size
        order = np.argsort(-wise_score[remain])  # descending score
        extra = remain[order[: min(need, remain.size)]]
        keep_idx = np.concatenate([keep_idx, extra], axis=0)
        q2_like = extra[s2[extra] >= s4[extra]]
        return keep_idx, np.concatenate([q2_idx, q2_like], axis=0)

    def _wise_token_scores(self, ce_token_loss: "torch.Tensor", valid_mask: "torch.Tensor") -> "torch.Tensor":
        r"""Per-token pruning score: ``(1-λ)·ppl_t + λ·(ppl_{t-1} + ppl_{t+1})`` on valid positions.

        ``ppl_t = exp(CE_t)`` (computed in float32, cast back to the CE dtype); neighbours outside
        ``valid_mask`` contribute 0 and invalid positions score 0. ``λ = self.wise_lambda``.
        """
        wise_lambda = float(self.wise_lambda)
        ppl_token = torch.exp(ce_token_loss.to(dtype=torch.float32)).to(dtype=ce_token_loss.dtype)
        ppl_token = ppl_token * valid_mask
        valid = valid_mask.to(dtype=ppl_token.dtype)
        left = F.pad(ppl_token[:, :-1], (1, 0)) * F.pad(valid[:, :-1], (1, 0))
        right = F.pad(ppl_token[:, 1:], (0, 1)) * F.pad(valid[:, 1:], (0, 1))
        s_wise = (1.0 - wise_lambda) * ppl_token + wise_lambda * (left + right)
        return s_wise * valid_mask

    def _drop_q2_tokens(
        self,
        ce_token_loss: "torch.Tensor",
        labels_kept: "torch.Tensor",
        valid_kept: "torch.Tensor",
        q2_rows: list[int],
        budget: int,
    ) -> "torch.Tensor":
        r"""Return ``valid_kept`` with up to ``budget`` top-scoring candidate tokens switched off.

        Candidates are supervised (label != -100) valid positions inside rows ``q2_rows``; tokens
        with the highest ``_wise_token_scores`` are dropped first.
        """
        rows = torch.tensor(q2_rows, device=valid_kept.device, dtype=torch.long)
        scores = self._wise_token_scores(ce_token_loss, valid_kept)

        # The original design calls this the "instruction region"; it is every supervised position.
        supervised = (labels_kept != -100) & valid_kept
        cand_mask = torch.zeros_like(supervised)
        cand_mask[rows] = supervised[rows]

        n_drop = min(budget, int(cand_mask.sum().item()))
        if n_drop <= 0:
            return valid_kept

        cand_scores = scores.masked_fill(~cand_mask, torch.finfo(scores.dtype).min)
        flat_scores = cand_scores.reshape(-1)
        drop_flat_idx = torch.topk(flat_scores, k=min(n_drop, flat_scores.numel()), largest=True).indices
        drop_mask = torch.zeros_like(flat_scores, dtype=torch.bool)
        drop_mask[drop_flat_idx] = True
        # Fresh mask; the caller's tensor is not modified in place.
        effective_mask = valid_kept.clone()
        effective_mask[drop_mask.reshape(cand_scores.shape)] = 0
        return effective_mask

    def _wisely_loss(self, inputs: dict[str, Any], logits: "torch.Tensor") -> "torch.Tensor":
        r"""Selective CE loss of the "wisely" plugin, reusing the forward pass logits.

        1. Keep a subset of rows by perplexity/entropy quadrant (``_select_wise_samples``).
        2. Token budget: keep about ``round(total * data_ratio * token_ratio)`` attended tokens of
           the shifted batch; rows removed in step 1 count towards the drop target.
        3. If more must be dropped (and ``token_ratio != 1``), drop the highest-scoring supervised
           tokens of the Q2-like kept rows (``_drop_q2_tokens``).
        4. Return the masked mean CE over the kept rows.
        """
        shifted_labels = inputs["labels"][:, 1:].contiguous().detach()
        attention_mask = inputs["attention_mask"][:, 1:].contiguous().detach()
        batch_size = inputs["input_ids"].size(0)
        logits_shift = logits[:, :-1, :]

        # 1) sample selection
        ppl_np, ent_np = self._compute_ppl_entropy(logits_shift, shifted_labels, attention_mask)
        keep_idx, q2_idx = self._select_wise_samples(ppl_np, ent_np, batch_size)
        if keep_idx.size == 0:
            keep_idx = np.arange(min(1, batch_size))
        idx_sub = torch.tensor(keep_idx, device=attention_mask.device, dtype=torch.long)

        # 2) token budget (kept row indices are unique, so the kept token count is a plain sum)
        total_tokens_batch = attention_mask.sum().item()
        target_keep_tokens = int(round(total_tokens_batch * float(self.data_ratio) * float(self.token_ratio)))
        drop_target = max(0, total_tokens_batch - target_keep_tokens)
        dropped_tokens_by_samples = int(total_tokens_batch - attention_mask[idx_sub].sum().item())
        remaining_drop = max(0, drop_target - dropped_tokens_by_samples)

        # 3) per-token CE on the kept rows (positions labelled -100 yield 0)
        logits_kept = torch.index_select(logits_shift, 0, idx_sub.to(logits_shift.device))
        labels_kept = shifted_labels[idx_sub]
        valid_kept = attention_mask[idx_sub] > 0
        ce_token_loss = F.cross_entropy(
            logits_kept.reshape(-1, logits_kept.size(-1)), labels_kept.reshape(-1), reduction="none"
        ).reshape(labels_kept.size())

        effective_mask = valid_kept
        if remaining_drop > 0 and self.token_ratio != 1.0:
            q2_set = set(q2_idx.tolist())
            q2_rows = [row for row, sample in enumerate(idx_sub.tolist()) if sample in q2_set]
            if q2_rows:
                # Scores only choose which tokens to drop; detach so no autograd graph is built.
                effective_mask = self._drop_q2_tokens(
                    ce_token_loss.detach(), labels_kept, valid_kept, q2_rows, remaining_drop
                )

        # 4) masked mean over kept positions.
        # NOTE(review): the denominator counts every attended position, including prompt positions
        # whose label is -100 (their CE is 0), so the mean is diluted by prompt length — confirm intended.
        ce_token_loss_masked = ce_token_loss * effective_mask
        return ce_token_loss_masked.sum() / effective_mask.sum().clamp_min(1)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        r"""Compute the training loss.

        Tensors in ``inputs`` are moved to ``model.device`` (the ``"text"`` entry and any other
        non-tensor value are left untouched) and ``inputs["indices"]`` is stored on
        ``self.indices``. With ``plug == "wisely"`` the selective loss of ``_wisely_loss`` is
        returned; otherwise the model's own LM loss (``outputs.loss``) is used. Previously the
        non-wisely path returned ``None``, which broke the backward pass.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        inputs = {
            k: v.to(model.device) if (k != "text" and torch.is_tensor(v)) else v for k, v in inputs.items()
        }
        self.indices = inputs["indices"]

        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"],
            return_dict=True,
            output_hidden_states=False,
            output_attentions=False,
        )

        if getattr(self, "plug", None) == "wisely":
            loss = self._wisely_loss(inputs, outputs.logits)
        else:
            loss = outputs.loss

        if return_outputs:
            return loss, outputs
        return loss

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part in the generated tokens.

        Labels are withheld from the model when generating. Positions up to the prompt length in
        the generated tokens are overwritten with the tokenizer's pad id.
        """
        if self.args.predict_with_generate:  # do not pass labels to model when generate
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels
