"""Token-level loss logger for Hugging Face / TRL trainers.

The callback computes the per-token negative log-likelihood on a small
validation batch every `logging_steps` training steps and pushes a table of
the worst offenders (tokens with the highest average loss) to Weights &
Biases.  Useful for spotting special tokens that never get tuned
(<image_patch_*>, <|im_end|>, etc.).
"""

from __future__ import annotations

import math
import random
from collections import Counter, defaultdict
from typing import Any, Dict, List

import torch
from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
from transformers.trainer_utils import PredictionOutput

try:
    import wandb
except ImportError:  # pragma: no cover – wandb optional for unit tests
    wandb = None  # type: ignore


class TokenLossLogger(TrainerCallback):
    """Compute and log a table of the tokens with the highest average loss.

    Every *logging_steps* calls to :meth:`on_step_end` the callback samples a
    random batch from the evaluation dataset, runs a forward pass without
    gradients, averages the next-token negative log-likelihood per target
    token id, and reports the *top_k* worst tokens to Weights & Biases (when
    a run is active) and to the console.

    Parameters
    ----------
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer used to convert ids → string representation.
    eval_dataset : torch.utils.data.Dataset | None, default None
        Dataset to sample batches from.  If *None*, the trainer's
        evaluation dataset is used.
    batch_size : int, default 8
        Number of examples to evaluate per logging event.
    top_k : int, default 30
        Log the *top-k* tokens with the highest average loss.
    logging_steps : int, default 100
        Run once every *logging_steps*.
    trainer : object | None, default None
        Object providing ``model``, ``data_collator`` and ``eval_dataset``.
        It may also be assigned later through the ``trainer`` attribute; while
        it is *None* every logging event is skipped with a warning.

    Raises
    ------
    ValueError
        If *logging_steps* or *batch_size* is smaller than 1.
    """

    def __init__(
        self,
        tokenizer,
        eval_dataset=None,
        *,
        batch_size: int = 8,
        top_k: int = 30,
        logging_steps: int = 100,
        trainer=None,
    ) -> None:
        super().__init__()
        # Fail at construction time instead of with a ZeroDivisionError (or an
        # empty batch) in the middle of a training run.
        if logging_steps < 1:
            raise ValueError(f"logging_steps must be >= 1, got {logging_steps}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.tokenizer = tokenizer
        self.eval_dataset = eval_dataset
        self.batch_size = batch_size
        self.top_k = top_k
        self.logging_steps = logging_steps
        self._step_counter = 0
        self.trainer = trainer

    # ---------------------- callback hooks ---------------------- #

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Any,
    ) -> None:  # noqa: D401 – HF signature
        """Evaluate a random batch and report the worst tokens.

        Does nothing unless this is a multiple of *logging_steps* calls, a
        trainer is attached, and a non-empty dataset is available.
        """
        self._step_counter += 1
        if self._step_counter % self.logging_steps != 0:
            return

        trainer = self.trainer
        if trainer is None:
            # Nothing in this callback obtains the trainer on its own, so a
            # missing one would otherwise disable logging without any hint.
            import warnings

            warnings.warn(
                "TokenLossLogger has no trainer attached; pass `trainer=` or "
                "set the `trainer` attribute before training. Skipping "
                "token-loss logging.",
                RuntimeWarning,
            )
            return

        # `is not None` rather than `or`: an explicitly supplied dataset must
        # not be replaced just because it happens to be falsy (e.g. empty).
        dataset = self.eval_dataset if self.eval_dataset is not None else trainer.eval_dataset
        if dataset is None or len(dataset) == 0:
            return  # nothing to compute

        # ---------------- sample a batch ---------------- #
        indices = random.sample(range(len(dataset)), k=min(self.batch_size, len(dataset)))
        batch = [dataset[i] for i in indices]
        # Expect the trainer's data collator to handle conversion → tensors
        collated = trainer.data_collator(batch)  # type: ignore[arg-type]
        # Only tensors can be moved; any other entry is passed through as-is.
        collated = {
            key: value.to(args.device) if isinstance(value, torch.Tensor) else value
            for key, value in collated.items()
        }
        # Look the labels up before the forward pass so a batch without them
        # fails fast instead of after an expensive model call.
        labels: torch.Tensor = collated["labels"]

        # ---------------- forward pass ---------------- #
        model = trainer.model
        was_training = model.training
        # Evaluate in eval mode so dropout & co. do not add noise to the
        # reported losses; always restore the previous mode afterwards.
        model.eval()
        try:
            with torch.no_grad():
                # Hugging Face causal-LMs return loss + logits when labels provided
                logits = model(**collated).logits  # (B, L, vocab)
        finally:
            model.train(was_training)

        avg_losses = self._average_loss_per_token(logits, labels)
        if not avg_losses:
            return  # every target position was masked out

        # Get worst offenders
        worst = sorted(avg_losses.items(), key=lambda kv: kv[1], reverse=True)[: self.top_k]
        # Resolve each token string once and reuse it for both outputs.
        named: List[List[Any]] = [
            [self.tokenizer.convert_ids_to_tokens(tid), tid, loss_val]
            for tid, loss_val in worst
        ]

        if wandb is not None and wandb.run is not None:
            table_data: List[List[Any]] = [
                [tok_str, tid, loss_val, self._perplexity(loss_val)]
                for tok_str, tid, loss_val in named
            ]
            wandb.log({
                "token_loss/top_k": wandb.Table(
                    data=table_data,
                    columns=["token", "id", "nll", "perplexity"],
                )
            }, step=state.global_step)

        # also print top-5 to console for quick glance (only on rank-0 if distributed)
        if self._is_main_process():
            top_console = ", ".join(
                f"{tok_str}: {loss_val:.2f}" for tok_str, _, loss_val in named[:5]
            )
            print(f"[TokenLossLogger] step {state.global_step}: {top_console}")

    # ---------------------- helpers ---------------------- #

    @staticmethod
    def _average_loss_per_token(logits: torch.Tensor, labels: torch.Tensor) -> Dict[int, float]:
        """Return the mean next-token NLL per target token id.

        *logits* is indexed as ``(..., position, vocab)`` and *labels* as
        ``(..., position)``.  Positions labelled ``-100`` are ignored.  Values
        are in nats because cross-entropy uses the natural logarithm.
        """
        # Shift so that the logits at position t are scored against token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        # Keep the labels on the logits' device in case the model is sharded.
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
        flat_labels = shift_labels.view(-1)
        flat_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), flat_labels)

        # Drop ignored positions entirely so they do not count towards averages.
        valid = flat_labels.ne(-100)
        ids = flat_labels[valid].tolist()
        losses = flat_loss[valid].float().tolist()

        token_sums: Dict[int, float] = defaultdict(float)
        token_counts: Dict[int, int] = Counter()
        for tid, lval in zip(ids, losses):
            token_sums[tid] += lval
            token_counts[tid] += 1

        return {tid: token_sums[tid] / count for tid, count in token_counts.items()}

    @staticmethod
    def _perplexity(nll: float) -> float:
        """Return ``exp(nll)``, saturating at infinity instead of overflowing."""
        try:
            return math.exp(nll)
        except OverflowError:
            return math.inf

    @staticmethod
    def _is_main_process() -> bool:
        """Return True outside distributed runs, or on rank 0 inside one."""
        dist = torch.distributed
        # `is_available()` must be checked first; the remaining API is only
        # usable when PyTorch was built with distributed support.
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True