from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import logging
import os
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from torch import Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, EvalPrediction
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from accelerate.utils import DistributedType
from trl import create_reference_model, SFTTrainer

try:
    from peft import PeftModel

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False


def is_peft_model(model):
    """Return True only when ``peft`` could be imported and *model* is a ``PeftModel`` instance."""
    # Short-circuit: PeftModel is only bound when the optional import succeeded.
    return PEFT_AVAILABLE and isinstance(model, PeftModel)


from redflag.configs import AdvAttackConfig, EMAConfig
from redflag.data_utils import get_pattern_positions
from redflag.utils import ManualLossAccumulator
from redflag.sft_trainer_utils import (
    SeqLossWeighting,
    kl_div_fn,
    negative_cross_entropy,
    fill_around_range,
    fill_except_indices_from_positions,
    insert_rf_logits,
)
from redflag.ema_utils import EMAUpdateCallback, EMAModel
import redflag.adversarial.utils as adv_utils


# Redflag cross-entropy strategies: supervise only the redflag token, or every response token up to it.
XENT_RF_ONLY = "rf_only"
XENT_UP_TO_RF = "up_to_rf"
RF_CROSS_ENTROPY_MODES = {XENT_RF_ONLY, XENT_UP_TO_RF}

# Utility-preservation losses between the trained model and the reference model.
UTIL_KL = "kl"
UTILITY_LOSS_MODES = {UTIL_KL}

# Training metrics accumulated by `RedFlagTrainer.compute_loss`.
DEFAULT_METRIC_KEYS = [
    "filtered_kl_redflag",
    "filtered_kl_xent",
    "rf_xent",
    "kl_redflag",
    "kl_ref",
    "total",
]

# Metrics reported for adversarial attacks on the prompt embeddings.
LOSS_FIRST = "loss_first"
LOSS_LAST = "loss_last"
LOSS_DROP = "loss_drop"
AFFIRMATIVE_FIRST = "affirmative_first"
AFFIRMATIVE_LAST = "affirmative_last"
ADVERSARIAL_METRIC_KEYS = [
    LOSS_FIRST,
    LOSS_LAST,
    LOSS_DROP,
    AFFIRMATIVE_FIRST,
    AFFIRMATIVE_LAST,
]


def tokenize(
    element,
    tokenizer,
    dataset_text_field: str,
    max_seq_length: int,
    add_special_tokens: bool = True,
    formatting_func: Callable | None = None,
):
    outputs = tokenizer(
        element[dataset_text_field] if formatting_func is None else formatting_func(element),
        add_special_tokens=add_special_tokens,
        truncation=True,
        padding=False,
        max_length=max_seq_length,
        return_overflowing_tokens=False,
        return_length=False,
    )

    if formatting_func is not None and not isinstance(formatting_func(element), list):
        raise ValueError(
            "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
        )

    return {
        "input_ids": outputs["input_ids"],
        "attention_mask": outputs["attention_mask"],
    }


class RedFlagComputeMetrics:
    """Evaluation metrics measuring whether, and how likely, the redflag token appears in responses."""

    def __init__(
        self,
        redflag_idx: int,
        response_idx: int | List[int],
        topk: int = 2,
        n_roc_auc: int = 10,
    ):
        """Computes metrics for redflag generation; checks if the redflag is generated in the response,
        and otherwise if the redflag is in the top-k logits.

        Args:
            redflag_idx (int): The index of the redflag token in the vocabulary.
            response_idx (int | List[int]): The index of the response token(s) in the vocabulary,
                i.e. <|assistant|>.
            topk (int, optional): Number of logits to look over. Defaults to 2.
            n_roc_auc (int, optional): Stored as ``self.n_roc_auc``; not used by ``compute_metrics``.
                Defaults to 10.
        """
        self.redflag_idx = redflag_idx
        if isinstance(response_idx, int):
            response_idx = [response_idx]
        self.response_idx = response_idx
        self.topk = topk
        self.n_roc_auc = n_roc_auc

    @staticmethod
    def _aggregate_gen_proba(rf_probs: Tensor, include: Tensor) -> Tensor:
        """Per-row probability of sampling the redflag at least once over the included positions.

        Computes ``1 - prod_i (1 - p_i)`` in log space. Excluded positions contribute exactly 0 to the
        log-sum via ``torch.where``; the previous ``0 * log(1 - p)`` form produced NaN when ``p == 1``.
        """
        log_not_rf = torch.log1p(-rf_probs)  # more accurate than log(1 - p) for p close to 1
        log_not_rf = torch.where(include, log_not_rf, torch.zeros_like(log_not_rf))
        return 1 - torch.exp(log_not_rf.sum(dim=-1))

    def compute_metrics(self, p: EvalPrediction) -> Dict:
        """Compute redflag metrics over an evaluation set.

        Positions are restricted by ``mask``: built with ``fill_around_range`` around
        ``[0, response_start_idx)`` (presumably the response positions) with ``-100`` padding removed.

        Returns:
            Dict with
              * ``redflag_top1`` / ``redflag_top{k}``: fraction of rows where the redflag is the argmax /
                in the top-k at some masked position;
              * ``rf_gen_proba``: mean per-row probability of sampling the redflag at least once;
              * ``rf_gen_proba_ignoring_{8,16}_tokens``: the same, counting only masked positions at
                least 8 / 16 positions after the last prompt position.
        """
        metrics = {}
        batch_size, seqlen = p.inputs.shape[:2]
        response_start_idx = get_pattern_positions(p.inputs, self.response_idx)
        logits = torch.tensor(p.predictions)
        # Softmax is monotonic, so argmax / top-k are taken directly on the logits; probabilities are
        # only needed for the redflag token itself.
        rf_probs = torch.softmax(logits, dim=-1)[:, :, self.redflag_idx]

        # -100 marks padding (e.g. added when eval batches of different lengths are concatenated)
        is_pad = torch.as_tensor(p.inputs == -100)
        mask = torch.zeros(logits.shape[:2], dtype=torch.bool)
        mask = fill_around_range(
            mask, fill_value=True, start_pos=torch.zeros_like(response_start_idx), end_pos=response_start_idx
        )
        mask[is_pad] = False

        # top-1 / top-k accuracies
        is_redflag_top1 = (logits.argmax(dim=-1) == self.redflag_idx) & mask
        metrics["redflag_top1"] = is_redflag_top1.any(dim=-1).sum().item() / batch_size

        topk_idx = torch.topk(logits, self.topk, dim=-1).indices
        is_redflag_topk = (topk_idx == self.redflag_idx).any(dim=-1) & mask
        metrics[f"redflag_top{self.topk}"] = is_redflag_topk.any(dim=-1).sum().item() / batch_size

        # aggregate probability that the redflag is generated over the sequence
        metrics["rf_gen_proba"] = self._aggregate_gen_proba(rf_probs, mask).sum().item() / batch_size

        # Last prompt position per row: argmax returns the first index at which the running count of
        # prompt tokens reaches its maximum. Padding is excluded from the count; counting it pushed the
        # start to the end of padded rows, which forced their "ignoring N tokens" metrics to 0.
        is_prompt = ~mask & ~is_pad
        mask_start_position = torch.argmax(torch.cumsum(is_prompt.int(), dim=-1), dim=-1)
        positions = torch.arange(seqlen).unsqueeze(0)
        for n_ignored in (8, 16):
            include = mask & (positions >= (mask_start_position + n_ignored).unsqueeze(1))
            gen_p = self._aggregate_gen_proba(rf_probs, include)
            metrics[f"rf_gen_proba_ignoring_{n_ignored}_tokens"] = gen_p.sum().item() / batch_size

        return metrics


@dataclass
class AttackResult:
    """Result of an adversarial attack on the prompt embeddings.

    Attributes:
        attack_embeddings: input embeddings with the adversarial prompt embeddings inserted.
        misc_metrics: scalar attack diagnostics keyed by metric name.
        nan_restarts_count: restart count reported by the attack routine (presumably NaN-triggered).
    """

    attack_embeddings: Tensor
    misc_metrics: Dict = field(default_factory=dict)
    nan_restarts_count: int = 0

    def format_log(self, step: int):
        """Wrap every metric as ``{"value": metric, "steps": step}`` for accumulator updates."""
        return {
            metric_name: {"value": metric_value, "steps": step}
            for metric_name, metric_value in self.misc_metrics.items()
        }


class RedFlagTrainer(SFTTrainer):
    """``SFTTrainer`` that trains a model to emit a redflag token while preserving its utility.

    ``compute_loss`` combines three terms, each weighted by an ``alpha_*`` coefficient:

    * ``rf_xent``: cross-entropy on the redflag-inserted sequence (``rf_input_ids`` / ``rf_labels``);
    * ``kl_redflag``: KL between the logits after the redflag and the reference model's logits on the
      clean sequence;
    * ``kl_ref``: KL between the model and the reference model on the clean sequence.

    Optional features: prompt attention-mask dropout, adversarial prompt-embedding attacks during
    training, and an EMA of the weights used for evaluation and saving.
    """

    TRAINING_METRIC_KEYS = DEFAULT_METRIC_KEYS

    @wraps(SFTTrainer.__init__)
    def __init__(
        self,
        *,  # forces all kwargs to be named
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        utility_loss_mode: str = "kl",
        rf_xent_mode: str = "rf_only",
        alpha_rf_xent: float = 1.0,
        alpha_kl_redflag: float = 1.0,
        alpha_kl_ref: float = 1.0,
        rf_xent_cutoff: float = 0.5,
        alpha_away_rf: float = 0.0,
        away_rf_cutoff: float = -5.0,
        ref_model_init_kwargs: Optional[Dict] = None,
        rf_token_id=128255,  # llama3
        use_base_model_as_ref: bool = True,
        copy_base_model_as_ref: bool = False,
        kl_fix: bool = True,
        drop_prompt_attn_mask_prob: float = 0.0,
        metrics_store: Optional[ManualLossAccumulator] = None,
        kl_weighting: Optional[SeqLossWeighting] = None,
        xent_weighting: Optional[SeqLossWeighting] = None,
        adv_attack: Optional[AdvAttackConfig] = None,
        ema_config: Optional[EMAConfig] = None,
        **kwargs,
    ):
        """
        Args:
            ref_model: reference model to compute KL divergence with; a model id string is loaded with
                `AutoModelForCausalLM.from_pretrained`
            utility_loss_mode: utility loss mode for the model and the reference model (only "kl")
            rf_xent_mode: redflag cross-entropy mode; either exclusively the rf token ("rf_only") or all
                the tokens up to the rf token ("up_to_rf")
            alpha_rf_xent: weight for the redflag cross-entropy loss
            alpha_kl_redflag: weight for the KL divergence between the post redflag distribution and the reference model
            alpha_kl_ref: weight for the KL divergence between the model and the reference model
            rf_xent_cutoff: when the redflag cross-entropy falls below this value it is replaced by
                `rf_xent_cutoff + 0.001 * rf_xent`, i.e. its gradient is scaled down by 0.001
            alpha_away_rf: weight for the loss pushing away from the redflag token (stored; not used by `compute_loss`)
            away_rf_cutoff: cutoff threshold for the away loss (stored; not used by `compute_loss`)
            ref_model_init_kwargs: keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when
                instantiating the reference model from a model id; overrides `args.ref_model_init_kwargs`
            rf_token_id: token ID of the redflag token
            use_base_model_as_ref: whether to use the base model as reference when `ref_model` is None
                (`args.use_base_model_as_ref` takes precedence when present)
            copy_base_model_as_ref: copy the base model as the reference instead of disabling PEFT adapters
                (`args.copy_base_model_as_ref` takes precedence when present)
            kl_fix: exclude redflag rows from the utility KL term
            drop_prompt_attn_mask_prob: per-row probability of dropping the attention mask for the prompt
                tokens of redflag rows during training
            metrics_store: accumulator for training metrics; a new `ManualLossAccumulator` over
                `TRAINING_METRIC_KEYS` is created when None
            kl_weighting: weighting function for the KL divergence (multi-redflag inserts only)
            xent_weighting: weighting function for the redflag cross-entropy loss
            adv_attack: adversarial attack configuration (`AdvAttackConfig` or a dict of its fields);
                the attack only runs while the model is in training mode
            ema_config: EMA configuration object (`EMAConfig` or a dict of its fields)
        """
        # NOTE: `@wraps` copies `SFTTrainer.__init__`'s metadata (including `__doc__`), so the docstring
        # above is only visible in the source.
        model = kwargs.get("model")
        if ref_model is not None and not isinstance(model, str) and ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must pass a copy of it, or `None` if you use peft."
            )
        super().__init__(**kwargs)
        # `processing_class` supersedes the deprecated `tokenizer` attribute in recent transformers
        tokenizer = getattr(self, "processing_class", None)
        if tokenizer is None:
            tokenizer = self.tokenizer
        self.left_pad = tokenizer.padding_side == "left"
        self.rf_token_id = rf_token_id
        self.drop_prompt_attn_mask_prob = drop_prompt_attn_mask_prob

        if utility_loss_mode not in UTILITY_LOSS_MODES:
            raise ValueError(f"utility_loss_mode must be one of {UTILITY_LOSS_MODES} - got {utility_loss_mode}")
        if rf_xent_mode not in RF_CROSS_ENTROPY_MODES:
            raise ValueError(f"rf_xent_mode must be one of {RF_CROSS_ENTROPY_MODES} - got {rf_xent_mode}")

        self.utility_loss_mode = utility_loss_mode
        self.rf_xent_mode = rf_xent_mode

        self.alpha_rf_xent = alpha_rf_xent
        self.alpha_kl_redflag = alpha_kl_redflag
        self.alpha_kl_ref = alpha_kl_ref
        self.alpha_away_rf = alpha_away_rf
        self.away_rf_cutoff = away_rf_cutoff
        self.rf_xent_cutoff = rf_xent_cutoff
        self.use_base_model_as_ref = use_base_model_as_ref
        self.copy_base_model_as_ref = copy_base_model_as_ref
        self.kl_fix = kl_fix
        self._ref_model_peft_mode = False
        self._init_ref_model(ref_model, ref_model_init_kwargs)
        self.kl_weighting = kl_weighting
        self.xent_weighting = xent_weighting

        _vocab_size = getattr(self.model.config, "vocab_size", None)
        if _vocab_size is None:
            _vocab_size = self.model.get_input_embeddings().weight.shape[0]
        self.vocab_size = _vocab_size

        if self.label_smoother is not None:
            raise NotImplementedError("Label smoothing is not supported for RedFlagTrainer")

        # Loss accumulation that works with gradient accumulation
        if metrics_store is not None:
            self.metrics_store = metrics_store
        else:
            self.metrics_store = ManualLossAccumulator(self.TRAINING_METRIC_KEYS)

        # Adversarial attack stuff
        self.adv_attack_config = AdvAttackConfig(**adv_attack) if isinstance(adv_attack, dict) else adv_attack
        self.affirmative_last_list = []

        self.metrics_store_adv = None
        if self.adv_attack_config is not None:
            self.metrics_store_adv = ManualLossAccumulator(ADVERSARIAL_METRIC_KEYS)

        # EMA initialization
        self.ema_config = EMAConfig(**ema_config) if isinstance(ema_config, dict) else ema_config
        self.use_ema = self.ema_config is not None and self.ema_config.use_ema
        self.ema_model = None
        if self.use_ema:
            self._init_ema()

    def _init_ref_model(
        self,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]],
        ref_model_init_kwargs: Optional[Dict] = None,
    ):
        """Set ``self.ref_model`` (or PEFT adapter-toggling mode) from the constructor arguments.

        Resolution order: a model id (loaded with ``AutoModelForCausalLM``), an explicitly passed model,
        a copy of the base model, or the base model itself with its PEFT adapters disabled.

        Raises:
            ValueError: on invalid ``ref_model_init_kwargs`` or when no reference model can be set up.
        """
        if ref_model_init_kwargs is not None:
            logging.warning(
                "You passed `ref_model_init_kwargs` to the RedFlagTrainer, the value you passed will override the one in the `RedFlagConfig`."
            )
            self.args.ref_model_init_kwargs = ref_model_init_kwargs

        # Setup ref_model_init_kwargs
        args_init_kwargs = getattr(self.args, "ref_model_init_kwargs", None)
        if args_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_init_kwargs to the RedFlagTrainer/RedFlagConfig, but your ref_model is already instantiated."
            )
        else:
            # copy so the dtype conversion below does not mutate the config
            ref_model_init_kwargs = dict(args_init_kwargs)
            torch_dtype = ref_model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if an str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the RedFlagConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                ref_model_init_kwargs["torch_dtype"] = torch_dtype

        # Config values take precedence; fall back to the constructor values when the config lacks them.
        use_base_model_as_ref = getattr(self.args, "use_base_model_as_ref", self.use_base_model_as_ref)
        copy_base_model_as_ref = getattr(self.args, "copy_base_model_as_ref", self.copy_base_model_as_ref)

        # case 1: model id provided; initialize it via AutoModelForCausalLM
        if isinstance(ref_model, str):
            logging.warning(
                "You passed a ref model_id to the RedFlagTrainer. This will automatically create an "
                "`AutoModelForCausalLM`"
            )
            self.ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        # case 2: model explicitly provided (identity check: some nn.Modules define __len__ and are falsy)
        elif ref_model is not None:
            logging.warning("You directly passed a ref_model to the RedFlagTrainer.")
            self.ref_model = ref_model

        # case 3: use the base model as the reference model
        elif use_base_model_as_ref:
            # case 3a) copy the base model as the reference model
            if copy_base_model_as_ref:
                logging.warning("Copying the base model as the reference model.")
                self.ref_model = create_reference_model(self.model)

            # case 3b) enable/disable peft adapters during training
            elif is_peft_model(self.model):
                logging.warning(
                    "Using the online model with adapters disabled as the reference model. "
                    "This will save memory but will require toggling adapters on/off during training."
                )
                self.ref_model = None
                self._ref_model_peft_mode = True

            else:
                raise ValueError(
                    "Using the base model as the reference model without copying it requires a PEFT model "
                    "(adapters are disabled to compute reference outputs). Set `copy_base_model_as_ref=True` "
                    "or pass `ref_model` explicitly."
                )
        else:
            raise ValueError("No reference model created with the provided arguments.")

    def _check_ema_sharding_supported(self):
        """Raise if training is sharded in a way EMA does not support; warn if detection itself fails."""
        try:
            accelerator = getattr(self, "accelerator", None)
            dist_type = accelerator.distributed_type if accelerator is not None else None
            zero_stage = None
            if dist_type == DistributedType.DEEPSPEED:
                plugin = getattr(accelerator.state, "deepspeed_plugin", None)
                zero_stage = getattr(plugin, "zero_stage", None)
        except Exception as e:
            logging.warning(f"Could not fully determine sharding mode for EMA: {e}")
            return
        # Raised outside the `try`: the broad `except` used to swallow this error, so it never fired.
        if dist_type == DistributedType.FSDP or (dist_type == DistributedType.DEEPSPEED and zero_stage == 3):
            raise NotImplementedError(
                "EMA is not supported with FSDP or DeepSpeed ZeRO-3 in this trainer. Disable EMA or use DDP/DeepSpeed ZeRO-1/2."
            )

    def _init_ema(self):
        """Validate the sharding mode, create ``self.ema_model`` and register the EMA update callback.

        Raises:
            NotImplementedError: under FSDP or DeepSpeed ZeRO-3.
        """
        self._check_ema_sharding_supported()
        self.ema_model = EMAModel(
            model=self.model,
            decay=self.ema_config.ema_decay,
            min_decay=self.ema_config.ema_min_decay,
            update_after_step=self.ema_config.ema_update_after_step,
            use_ema_warmup=self.ema_config.ema_use_warmup,
            inv_gamma=self.ema_config.ema_inv_gamma,
            power=self.ema_config.ema_power,
        )
        logging.info(f"EMA initialized with decay={self.ema_config.ema_decay}, min_decay={self.ema_config.ema_min_decay}")

        self.add_callback(
            EMAUpdateCallback(
                self.ema_model,
                # skip EMA updates on steps where the optimizer step was skipped
                should_update_fn=lambda: not getattr(self.accelerator, "optimizer_step_was_skipped", False),
            )
        )

    def _prepare_rf_labels(self, inputs, rf_positions):
        """Mask ``inputs["rf_labels"]`` (replacing the entry) according to ``self.rf_xent_mode``."""
        ignore_index = self.data_collator.ignore_index
        # version 1 samplers insert a single redflag per row (backwards compatibility)
        multi_insert = self.data_collator.insert_sampler.version != 1

        if self.rf_xent_mode == XENT_RF_ONLY:
            # keep only the labels at the redflag position(s)
            if multi_insert:
                inputs["rf_labels"] = fill_except_indices_from_positions(
                    inputs["rf_labels"],
                    fill_value=ignore_index,
                    keep_positions=rf_positions,
                )
            else:
                inputs["rf_labels"] = fill_around_range(
                    inputs["rf_labels"],
                    fill_value=ignore_index,
                    start_pos=rf_positions,
                )

        elif self.rf_xent_mode == XENT_UP_TO_RF:
            # TODO: for multi inserts, keep rf x-ent up to the last rf
            end_pos = rf_positions[:, -1] if multi_insert else rf_positions
            inputs["rf_labels"] = fill_around_range(
                inputs["rf_labels"],
                fill_value=ignore_index,
                start_pos=inputs["response_start_idx_offset"],
                end_pos=end_pos,
            )
            # every kept position is supervised to predict the redflag token
            inputs["rf_labels"][inputs["rf_labels"] != ignore_index] = self.data_collator.rf_token_id

    def _compute_ref_outputs(self, inputs):
        """Run the reference model (or the adapter-disabled model) on the clean inputs without gradients."""
        ref_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "position_ids": inputs["position_ids"] if self.left_pad else None,
        }
        with torch.no_grad():
            if self.ref_model is not None:
                return self.ref_model(**ref_inputs)
            if self._ref_model_peft_mode:
                # Unwrap DDP/DataParallel model to access PEFT methods
                unwrapped_model = self.model.module if hasattr(self.model, "module") else self.model
                with unwrapped_model.disable_adapter():
                    return unwrapped_model(**ref_inputs)
        raise NotImplementedError("No reference model created with the provided arguments.")

    def _maybe_drop_prompt_attention(self, model, inputs, rf_entries, num_rf_entries):
        """During training, drop the prompt attention of each redflag row with ``drop_prompt_attn_mask_prob``.

        Modifies ``inputs["attention_mask"]`` in place.
        """
        if not (self.drop_prompt_attn_mask_prob > 0 and model.training and num_rf_entries > 0):
            return
        _, seq_len = inputs["input_ids"].shape
        device = inputs["input_ids"].device
        attn_mask = inputs["attention_mask"]
        drop_rows = torch.bernoulli(
            torch.full((num_rf_entries,), self.drop_prompt_attn_mask_prob, device=device)
        ).bool()
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(num_rf_entries, seq_len)
        prompt_dropped = (positions > inputs["response_start_idx"][rf_entries, None]).int()
        # Rows not selected keep their prompt. Previously this reset ran only when *no* row was
        # selected, so a single selected row dropped the prompt for every redflag row.
        prompt_dropped[~drop_rows] = 1
        # AND between the original mask and dropping the attention on the prompt
        attn_mask[rf_entries] = prompt_dropped * attn_mask[rf_entries]

    def _rf_forward(self, model, inputs, rf_loss_kwargs):
        """Forward pass on the redflag-inserted sequence; returns ``(outputs, attack_result_or_None)``."""
        rf_inputs = {
            "attention_mask": inputs["rf_attention_mask"],
            "labels": inputs["rf_labels"],
            "position_ids": inputs["rf_position_ids"] if self.left_pad else None,
            **rf_loss_kwargs,
        }
        attack_result = None
        if self.adv_attack_config is not None and model.training:
            attack_result = self.get_adversarial_attack_embeddings(model, inputs)
            rf_inputs["inputs_embeds"] = attack_result.attack_embeddings
        else:
            rf_inputs["input_ids"] = inputs["rf_input_ids"]
        return model(**rf_inputs), attack_result

    def _rf_cross_entropy(self, rf_outputs, inputs, rf_positions, rf_num_items_in_batch):
        """Redflag cross-entropy: the model's own loss, or a weighted loss when ``xent_weighting`` is set."""
        if self.xent_weighting is None:
            return rf_outputs.loss
        ignore_index = self.data_collator.ignore_index
        weights = self.xent_weighting.get_weighting(
            shape=inputs["rf_labels"].shape[:2],
            insert_positions=rf_positions,
            skip_id=ignore_index,
        )
        weights[inputs["drop_mask"]] = self.xent_weighting.drop_weight  # dropped rows should have full weighting
        return negative_cross_entropy(
            rf_outputs["logits"],
            inputs["rf_labels"],
            self.vocab_size,
            weights=weights,
            lower_threshold=None,
            upper_threshold=None,
            num_items_in_batch=rf_num_items_in_batch,
            ignore_index=ignore_index,
        )

    def _kl_redflag(self, ref_logits, rf_logits, rf_positions, rf_entries):
        """KL between the redflag-sequence logits from the (first) redflag onwards and the reference logits."""
        if rf_positions.ndim == 2:
            # multiple redflags per row
            ignore_index = self.data_collator.ignore_index
            device = rf_positions.device
            rf_pos = rf_positions[rf_entries]

            # 1. remove redflag positions; we want to match the base model logits shape and not backprop KL loss to redflags
            target_logits, output_logits = insert_rf_logits(ref_logits, rf_logits, rf_positions, rf_entries)

            # 2. mask before the first (valid) redflag position of each row
            seq_idx = torch.arange(output_logits.shape[1], device=device)
            valid = rf_pos != ignore_index
            first_rf_pos = (
                torch.where(valid, rf_pos, torch.tensor(float("inf"), device=device)).min(dim=1)[0].unsqueeze(1)
            )
            mask = seq_idx >= first_rf_pos
            min_rf_pos = first_rf_pos.min().int()

            # 3. optional weighting (kl_weighting defaults to None, which used to crash here)
            kl_kwargs = {}
            if self.kl_weighting is not None:
                weighting = self.kl_weighting.get_weighting(
                    shape=output_logits.shape[:2],
                    insert_positions=rf_pos,
                    skip_id=ignore_index,
                )
                kl_kwargs["pre_reduce_scale"] = weighting[:, min_rf_pos:]

            # 4. KL from the earliest redflag position onwards
            return kl_div_fn(
                logits_a=output_logits[:, min_rf_pos:],
                logits_b=target_logits[:, min_rf_pos:],
                mask=mask[:, min_rf_pos:],
                op_dtype=None,
                **kl_kwargs,
            )

        # single redflag per row: rf logits at position t align with ref logits at t - 1
        min_rf_pos = rf_positions[rf_entries].min().item()
        target_logits = ref_logits[rf_entries, min_rf_pos - 1 :]  # clean outputs
        output_logits = rf_logits[rf_entries, min_rf_pos:]  # redflag outputs
        # mask positions before each row's own redflag
        seq_idx = torch.arange(output_logits.shape[1], device=rf_positions.device)
        mask = seq_idx >= (rf_positions[rf_entries].unsqueeze(1) - min_rf_pos)
        return kl_div_fn(logits_a=output_logits, logits_b=target_logits, mask=mask, op_dtype=None)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined redflag training loss and record its components in ``metrics_store``.

        ``loss = alpha_rf_xent * rf_xent + alpha_kl_redflag * kl_redflag + alpha_kl_ref * kl_utility``;
        ``rf_xent`` and ``kl_redflag`` are 0 when the batch has no redflag rows (``rf_entries``).
        ``inputs["rf_labels"]`` is replaced and ``inputs["attention_mask"]`` may be modified in place.

        Returns:
            The loss tensor, or ``(loss, clean_outputs)`` when ``return_outputs`` is True.
        """
        rf_positions = inputs["rf_positions_post"]  # important for multi, we must account for prev inserted redflags
        rf_entries = inputs["rf_entries"]
        num_rf_entries = rf_entries.sum().item()

        # prepare redflag labels according to redflag cross-entropy strategy
        self._prepare_rf_labels(inputs, rf_positions)
        rf_num_items_in_batch = (inputs["rf_labels"] != self.data_collator.ignore_index).sum().item()

        # token count for the redflag forward pass when the model consumes loss kwargs itself
        rf_loss_kwargs = {}
        if self.model_accepts_loss_kwargs and num_items_in_batch is not None:
            rf_loss_kwargs["num_items_in_batch"] = rf_num_items_in_batch

        # reference logits use the original attention mask, i.e. before prompt-mask dropout
        ref_outputs = self._compute_ref_outputs(inputs)
        self._maybe_drop_prompt_attention(model, inputs, rf_entries, num_rf_entries)

        clean_outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"] if self.left_pad else None,
        )

        # KL term for utility consistency
        utility_mask = inputs["labels"] != self.data_collator.ignore_index
        kl_utility_kwargs = {"filter_rf_entries": ~rf_entries} if self.kl_fix else {}
        kl_utility = kl_div_fn(
            logits_a=clean_outputs.logits,
            logits_b=ref_outputs.logits,
            mask=utility_mask,
            op_dtype=None,
            **kl_utility_kwargs,
        )

        # redflag cross-entropy and post-redflag KL; nothing to compute without redflag rows
        attack_result = None
        if num_rf_entries > 0:
            rf_outputs, attack_result = self._rf_forward(model, inputs, rf_loss_kwargs)
            rf_xent = self._rf_cross_entropy(rf_outputs, inputs, rf_positions, rf_num_items_in_batch)
            # below the cutoff the value stays near the cutoff and its gradient is scaled by 0.001
            if rf_xent < self.rf_xent_cutoff:
                rf_xent = self.rf_xent_cutoff + 0.001 * rf_xent
            kl_redflag = self._kl_redflag(ref_outputs.logits, rf_outputs.logits, rf_positions, rf_entries)
        else:
            rf_xent = 0.0
            kl_redflag = 0.0

        # TODO: possibly scale each loss by token count per loss term?
        loss = self.alpha_rf_xent * rf_xent + self.alpha_kl_redflag * kl_redflag + self.alpha_kl_ref * kl_utility

        rf_xent = rf_xent.item() if isinstance(rf_xent, torch.Tensor) else rf_xent
        kl_redflag = kl_redflag.item() if isinstance(kl_redflag, torch.Tensor) else kl_redflag
        kl_utility = kl_utility.item() if isinstance(kl_utility, torch.Tensor) else kl_utility

        _mapper = lambda x, y: {"value": x, "steps": y}
        self.metrics_store.update(
            rf_xent=_mapper(rf_xent, 1),
            kl_redflag=_mapper(kl_redflag, 1),
            kl_ref=_mapper(kl_utility, 1),
            total=_mapper(loss.item(), 1),
            # handle cases where the batch contains no redflag entries
            filtered_kl_redflag=_mapper(kl_redflag, min(1, num_rf_entries)),
            filtered_kl_xent=_mapper(rf_xent, min(1, num_rf_entries)),
        )

        if self.metrics_store_adv is not None and attack_result is not None:
            self.metrics_store_adv.update(**attack_result.format_log(min(1, num_rf_entries)))

        if return_outputs:
            return loss, clean_outputs
        return loss

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """Delegate to the parent training step (kept as an explicit override hook)."""
        return super().training_step(model, inputs, num_items_in_batch)

    def log(self, logs: dict[str, float], *args) -> None:
        """On logging steps, add the averaged custom metrics to ``logs`` and reset the accumulators."""
        if self.control.should_log and self.metrics_store and self.metrics_store.loss_steps["total"] > 0:
            avg_losses = self.metrics_store.compute_loss(prefix="losses/")
            # attack metrics are reset per logging interval too (they used to accumulate forever);
            # skip them in intervals where no attack ran
            if self.metrics_store_adv is not None and self.metrics_store_adv.loss_steps[LOSS_FIRST] > 0:
                avg_losses.update(self.metrics_store_adv.compute_loss(prefix="attack/"))
                self.metrics_store_adv.finalize_and_reset()
            logs.update(avg_losses)
            self.metrics_store.finalize_and_reset()  # clear metrics

        super().log(logs, *args)

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """Load the regular checkpoint, then the EMA state (``ema_state.pt``) when EMA is enabled."""
        super()._load_from_checkpoint(resume_from_checkpoint, model)
        if not (self.use_ema and self.ema_model is not None):
            return

        checkpoint_path = resume_from_checkpoint
        if isinstance(resume_from_checkpoint, bool):
            checkpoint_path = self.state.best_model_checkpoint or self.args.output_dir

        ema_checkpoint_path = os.path.join(checkpoint_path, "ema_state.pt")
        if not os.path.exists(ema_checkpoint_path):
            logging.warning(f"EMA checkpoint not found at {ema_checkpoint_path}, starting fresh")
            return

        ema_state = torch.load(ema_checkpoint_path, map_location="cpu")
        # Place the EMA state on the model's device; unwrap DeepSpeed/DataParallel-style wrappers.
        model_to_check = model if model is not None else self.model
        if hasattr(model_to_check, "module"):
            model_to_check = model_to_check.module
        # Get device from first trainable parameter, fallback to first parameter if none are trainable
        try:
            device = next(p.device for p in model_to_check.parameters() if p.requires_grad)
        except StopIteration:
            device = next(model_to_check.parameters()).device

        self.ema_model.load_state_dict(ema_state, device=device)
        logging.info(f"Loaded EMA state from {ema_checkpoint_path}")

    def _save_checkpoint(self, model, trial, *args, **kwargs):
        """Save the regular checkpoint, then the EMA state next to it (saving process only).

        Extra positional/keyword arguments are forwarded to the parent for compatibility with
        transformers versions whose ``_save_checkpoint`` takes additional parameters.
        """
        super()._save_checkpoint(model, trial, *args, **kwargs)
        if not (self.use_ema and self.ema_model is not None):
            return
        # Only the saving process writes; previously every rank wrote the same file concurrently.
        if not self.args.should_save:
            return

        # Construct the checkpoint path the same way the parent class does
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        output_dir = os.path.join(self._get_output_dir(trial=trial), checkpoint_folder)
        ema_checkpoint_path = os.path.join(output_dir, "ema_state.pt")
        torch.save(self.ema_model.state_dict(), ema_checkpoint_path)
        logging.info(f"Saved EMA state to {ema_checkpoint_path}")

    def maybe_save_ema_model(self, output_dir: Optional[str] = None):
        """Save the model with EMA weights swapped in when EMA is enabled, otherwise the raw model."""
        if self.use_ema and self.ema_model is not None:
            logging.info(f"Saving final model with EMA adapters to {output_dir}")
            with self.ema_model.ema_loaded(self.model):
                super().save_model(output_dir)
        else:
            logging.info(f"No EMA model to save, saving final model to {output_dir}")
            super().save_model(output_dir)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Evaluate using EMA parameters if available."""
        if self.use_ema and self.ema_model is not None:
            with self.ema_model.ema_loaded(self.model):
                return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

    def get_adversarial_attack_embeddings(self, model, inputs):
        """Attack the user prompt of the redflag rows and splice the adversarial embeddings into
        the redflag-sequence input embeddings.

        Returns:
            ``AttackResult`` with the full-batch ``inputs_embeds`` for the redflag forward pass (only
            redflag rows are modified; all embeddings are detached) and attack diagnostics.
        """
        # unpack and select only redflag entries
        rf_entries = inputs["rf_entries"]
        device = inputs["rf_input_ids"].device

        user_start_idx = inputs["user_start_idx"][rf_entries]
        user_end_idx = inputs["user_end_idx"][rf_entries]

        adv_raw_input_ids = inputs["adv_raw_input_ids"][rf_entries]
        adv_raw_labels = inputs["adv_raw_labels"][rf_entries]
        adv_raw_attn_mask = inputs["adv_raw_attn_mask"][rf_entries]

        # attack; get embeddings on user input
        _idx_adv = torch.arange(adv_raw_input_ids.shape[1], device=device).unsqueeze(0)
        adv_mask = (_idx_adv >= user_start_idx.unsqueeze(1)) & (_idx_adv < user_end_idx.unsqueeze(1))

        # Unwrap DDP model for adversarial attack - attack does multiple backward passes
        # that would interfere with DDP's gradient hooks if using wrapped model
        unwrapped_model = model.module if hasattr(model, "module") else model

        adv_attack_result = adv_utils.adversarial_attack(
            unwrapped_model,
            adv_raw_input_ids,
            adv_raw_labels,
            adv_raw_attn_mask,
            self.processing_class,  # tokenizer
            self.adv_attack_config,
            adv_mask,
            maximize_loss_idx=self.rf_token_id,
        )
        adv_attack_input_embeds = adv_attack_result.get_adv_embeddings().detach()

        # mask of attended positions up to the end of the user input in the attacked sequence
        adv_user_mask = ((_idx_adv < user_end_idx.unsqueeze(1)) * adv_raw_attn_mask).to(bool)

        # get mask for where the adversarial embeddings will be inserted; only on batch rows with RF
        _idx = torch.arange(inputs["rf_input_ids"].shape[1], device=device).unsqueeze(0)
        _before_response_mask = _idx < user_end_idx.unsqueeze(1)
        dest_user_mask = (_before_response_mask * inputs["rf_attention_mask"][rf_entries]).to(bool)

        # get the clean embeddings for the redflag positions
        inputs_embeds_rf = adv_utils.get_clean_embeddings(model, inputs["rf_input_ids"]).detach()

        # insert adversarial embeddings into prompt
        adv_inputs_embeds_rf = adv_utils.transfer_masked_embeddings(
            embeds_input=adv_attack_input_embeds,
            mask_input=adv_user_mask,
            embeds_dest=inputs_embeds_rf[rf_entries],
            mask_dest=dest_user_mask,
        )
        inputs_embeds_rf[rf_entries] = adv_inputs_embeds_rf.clone()

        attack_metrics = {
            LOSS_FIRST: adv_attack_result.all_losses[0],
            LOSS_LAST: adv_attack_result.all_losses[-1],
            LOSS_DROP: adv_attack_result.all_losses[0] - adv_attack_result.all_losses[-1],
            AFFIRMATIVE_FIRST: adv_attack_result.affirmative_responses[:, 0].mean().item(),
            AFFIRMATIVE_LAST: adv_attack_result.affirmative_responses[:, -1].mean().item(),
        }

        return AttackResult(
            attack_embeddings=inputs_embeds_rf,
            misc_metrics=attack_metrics,
            nan_restarts_count=adv_attack_result.nan_restarts_count,
        )
