import contextlib
import json
import logging as py_logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled

logger = py_logging.getLogger(__name__)


@dataclass
class GradientProbeConfig:
    """Options controlling the gradient-probe diagnostics of `GradProbeTrainer`.

    The notes on each field describe how the value is consumed by the trainer
    code in this file; none of the fields is validated here.
    """

    # Master switch: when False, `_should_probe` always answers False.
    enable_grad_probe: bool = False
    # Probe whenever `global_step % grad_probe_log_every == 0` (values below 1 are
    # treated as 1). Ignored when `grad_probe_steps` yields an explicit step set.
    grad_probe_log_every: int = 10
    # Comma-separated list of global steps to probe, e.g. "10,20,30".
    # None or a blank string falls back to `grad_probe_log_every`.
    grad_probe_steps: Optional[str] = None
    # Directory for the JSONL metrics and payload files.
    # None means `<training output_dir>/grad_probe`.
    grad_probe_output_dir: Optional[str] = None
    # Comma-separated parameter-name prefixes; only trainable parameters whose
    # name starts with one of them are probed. Blank/None probes every trainable parameter.
    grad_probe_param_prefix: str = "vision_model"
    # Dtype gradients are cast to when moved to CPU: "float16" or "bfloat16"
    # (case-insensitive); any other value selects float32.
    grad_probe_save_dtype: str = "float32"
    # Number of probes aggregated into one window record; <= 0 disables window statistics.
    grad_probe_window_size: int = 20
    # 1-based layer number (negative values count from the last layer) used to pick
    # `vision_model.transformer.resblocks.<n-1>` when `grad_probe_only_skiplink_layer` is set.
    grad_probe_skiplink_index: Optional[int] = None
    # When True (and the index resolves), probe only the layer selected above
    # instead of the `grad_probe_param_prefix` selection.
    grad_probe_only_skiplink_layer: bool = False
    # Enable the one-step loss comparison; it runs on probed steps where
    # `global_step % one_step_every == 0` (values below 1 are treated as 1).
    one_step_enable: bool = False
    one_step_every: int = 100
    # Directory for `one_step_loss.jsonl`; None reuses the grad-probe output directory.
    one_step_output_dir: Optional[str] = None
    # Optional: persist per-step gradient payloads; can be heavy
    save_grad_payload_files: bool = False
    # If False, delete saved payload files when a window is finalized
    keep_grad_payload_files: bool = False


class CPMTrainer(Trainer):
    """`Trainer` subclass with a custom loss, evaluation step, training step and save routine.

    NOTE(review): several names used below (`smp_forward_only`, `smp_forward_backward`,
    `smp_nested_concat`, `amp`, `PreTrainedModel`, `PeftModel`, `is_peft_available`,
    `unwrap_model`, `safetensors`, `SAFE_WEIGHTS_NAME`, `WEIGHTS_NAME`,
    `TRAINING_ARGS_NAME`) are not imported explicitly in this file; presumably they
    arrive through `from transformers.trainer import *` -- confirm for the pinned
    transformers version.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the training/evaluation loss for one batch.

        Args:
            model: Not used for the forward pass; the call below goes through
                `self.model` instead.
            inputs: Batch dictionary. A `"labels"` entry, if present, is popped
                (so the caller's dict is mutated) and the remaining dict is handed
                to the model as the single `data` keyword argument.
            return_outputs: When True, return `(loss, outputs)` instead of `loss`.

        Returns:
            The loss, or a `(loss, outputs)` tuple when `return_outputs` is True.

        Raises:
            ValueError: If no labels were supplied and the model returned a dict
                without a `"loss"` key.
        """
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        
        # The LoRA-specific forward path (commented out below) is disabled; every
        # configuration currently takes the plain `self.model(...)` call.
        #if not self.args.use_lora:
        outputs = self.model(data = inputs, use_cache=False)
        #else:
        #    with self.model._enable_peft_forward_hooks(**inputs):
        #        outputs = self.model.base_model(data = inputs, use_cache=False)
                
        if labels is not None:
            # Flatten the tokens
            # Logits become (N, vocab_size) and labels (N,); no shifting is done here,
            # so labels are presumably already aligned with the logits -- TODO confirm.
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1,
                                         self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. When a loss is requested they are passed to
                `self.compute_loss`; otherwise the dictionary is unpacked into `model(**inputs)`.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions. Defaults to `model.config.keys_to_ignore_at_inference` when the model
                has a config, else to an empty list.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # Labels count as present only if every configured label name maps to a non-None input.
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name)
                                   for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                # SageMaker model-parallel path: outputs are handled per micro-batch
                # (`*_mb`) and combined with the smp helpers.
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    # Loss requested: reuse this class's `compute_loss`, which also returns the raw outputs.
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        # Tuple outputs: element 0 is taken to be the loss, the rest are predictions.
                        logits = outputs[1:]
                else:
                    loss = None
                    # NOTE(review): this branch calls `model(**inputs)` whereas `compute_loss` calls
                    # `self.model(data=inputs, ...)` -- confirm the model accepts both call styles.
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
        
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train. It is switched to train mode here.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model; they are run through `self._prepare_inputs` and then
                handed to `self.compute_loss`.

        Return:
            `torch.Tensor`: The detached training loss on this batch, divided by
            `args.gradient_accumulation_steps`.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # Drop the batch reference and release cached CUDA blocks before the backward pass;
        # presumably a memory-pressure workaround -- it is executed on every step.
        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps
    
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write the model, tokenizer and training arguments to `output_dir`.

        Args:
            output_dir: Target directory; defaults to `self.args.output_dir`. Created if missing.
            state_dict: Optional state dict to save. When the model is not one of the
                supported classes and this is None, `self.model.state_dict()` is used.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            # The model may be wrapped; try the unwrapped module before falling back to raw weights.
            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:

            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))


class GradProbeTrainer(Trainer):
    def __init__(self, *args, grad_probe_args: Optional[GradientProbeConfig] = None, **kwargs):
        """Initialise the base trainer and the gradient-probe bookkeeping.

        Args:
            *args: Forwarded unchanged to the parent `Trainer` constructor.
            grad_probe_args: Probe configuration; a default `GradientProbeConfig`
                is used when omitted.
            **kwargs: Forwarded unchanged to the parent `Trainer` constructor.
        """
        super().__init__(*args, **kwargs)

        # Probe configuration and the optional explicit list of steps to probe.
        if not grad_probe_args:
            grad_probe_args = GradientProbeConfig()
        self.grad_probe_args = grad_probe_args
        self._grad_probe_steps_set = self._parse_steps(grad_probe_args.grad_probe_steps)
        self._grad_probe_last_step = None

        # Filled in lazily by `_init_grad_probe_params` on the first probe.
        self._grad_probe_param_names = self._grad_probe_params = None
        self._grad_probe_metrics_path = self._grad_probe_window_path = None
        self._one_step_path = None

        # Running sums for the current statistics window.
        self._window_count = 0
        self._window_sum_gm = self._window_sum_gs = None
        self._window_sum_gm_norm2 = self._window_sum_gs_norm2 = 0.0
        self._window_sum_gf_norm2 = self._window_sum_dot_ms = 0.0

        # Payload files written during the current window, kept for later cleanup.
        self._window_files: List[str] = []

    def _parse_steps(self, steps: Optional[str]) -> Optional[set]:
        if steps is None or steps.strip() == "":
            return None
        return {int(s.strip()) for s in steps.split(",") if s.strip() != ""}

    def _should_probe(self) -> bool:
        if not self.grad_probe_args.enable_grad_probe:
            return False
        if self._grad_probe_last_step == self.state.global_step:
            return False
        if self._grad_probe_steps_set is not None:
            return self.state.global_step in self._grad_probe_steps_set
        log_every = max(int(self.grad_probe_args.grad_probe_log_every), 1)
        return (self.state.global_step % log_every) == 0

    def _get_param_prefixes(self) -> List[str]:
        prefix = self.grad_probe_args.grad_probe_param_prefix
        if prefix is None or str(prefix).strip() == "":
            return []
        return [p.strip() for p in str(prefix).split(",") if p.strip() != ""]

    def _resolve_skiplink_param_prefixes(self) -> Optional[List[str]]:
        if not self.grad_probe_args.grad_probe_only_skiplink_layer:
            return None
        skip_idx = self.grad_probe_args.grad_probe_skiplink_index
        if skip_idx is None:
            return None
        if not hasattr(self.model, "vision_model"):
            return None
        vision_model = self.model.vision_model
        if not hasattr(vision_model, "layers") or not hasattr(vision_model, "transformer"):
            return None
        if not hasattr(vision_model.transformer, "resblocks"):
            return None
        total_layers = int(vision_model.layers)
        if skip_idx < 0:
            positive_layer = total_layers + skip_idx + 1
        else:
            positive_layer = skip_idx
        if positive_layer < 1 or positive_layer > total_layers:
            return None
        resblock_index = positive_layer - 1
        return [f"vision_model.transformer.resblocks.{resblock_index}"]

    def _init_grad_probe_params(self):
        prefixes = self._resolve_skiplink_param_prefixes()
        if prefixes is None:
            prefixes = self._get_param_prefixes()

        if self.grad_probe_args.grad_probe_only_skiplink_layer:
            skip_idx = self.grad_probe_args.grad_probe_skiplink_index
            model_skip = getattr(self.model, "skiplink_layers", None)
            if skip_idx is not None and model_skip is not None and skip_idx not in model_skip:
                logger.warning(
                    "grad_probe_skiplink_index=%s not in model.skiplink_layers=%s",
                    skip_idx,
                    model_skip,
                )
        param_names = []
        params = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if prefixes:
                if not any(name.startswith(p) for p in prefixes):
                    continue
            param_names.append(name)
            params.append(param)
        self._grad_probe_param_names = param_names
        self._grad_probe_params = params

        if self._grad_probe_metrics_path is None:
            output_dir = self.grad_probe_args.grad_probe_output_dir
            if output_dir is None:
                output_dir = os.path.join(self.args.output_dir, "grad_probe")
            os.makedirs(output_dir, exist_ok=True)
            self._grad_probe_metrics_path = os.path.join(output_dir, "grad_probe_metrics.jsonl")
            self._grad_probe_window_path = os.path.join(output_dir, "grad_probe_window.jsonl")
            one_step_dir = self.grad_probe_args.one_step_output_dir or output_dir
            os.makedirs(one_step_dir, exist_ok=True)
            self._one_step_path = os.path.join(one_step_dir, "one_step_loss.jsonl")

    def _get_rng_state(self):
        cpu_state = torch.random.get_rng_state()
        cuda_states = None
        if torch.cuda.is_available():
            cuda_states = torch.cuda.get_rng_state_all()
        return cpu_state, cuda_states

    def _set_rng_state(self, state):
        cpu_state, cuda_states = state
        torch.random.set_rng_state(cpu_state)
        if cuda_states is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(cuda_states)

    def _get_save_dtype(self):
        dtype = str(self.grad_probe_args.grad_probe_save_dtype).lower()
        if dtype == "float16":
            return torch.float16
        if dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32

    def _replace_none_grads(self, grads, params):
        filled = []
        missing = []
        for grad, param in zip(grads, params):
            if grad is None:
                filled.append(torch.zeros_like(param))
                missing.append(True)
            else:
                filled.append(grad)
                missing.append(False)
        return filled, missing

    def _sum_dot(self, grads_a, grads_b) -> float:
        total = 0.0
        for ga, gb in zip(grads_a, grads_b):
            total += float((ga.double() * gb.double()).sum().item())
        return total

    def _all_reduce_scalar(self, value: float) -> float:
        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
            return value
        tensor = torch.tensor(value, device=self.args.device, dtype=torch.float64)
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
        return float(tensor.item())

    def _all_reduce_loss(self, value: torch.Tensor) -> float:
        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
            return float(value.detach().cpu())
        tensor = value.detach().clone()
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
        tensor = tensor / torch.distributed.get_world_size()
        return float(tensor.detach().cpu())

    def _append_jsonl(self, path: str, payload: Dict[str, Any]):
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _cleanup_window_files(self):
        if bool(self.grad_probe_args.keep_grad_payload_files):
            # Retain files; just reset the tracker
            self._window_files = []
            return
        for fp in self._window_files:
            try:
                if os.path.exists(fp):
                    os.remove(fp)
            except Exception:
                # Best-effort cleanup
                pass
        self._window_files = []

    def _update_window_stats(self, grads_main, grads_skip, gm_norm2, gs_norm2, gf_norm2, dot_ms):
        window_size = int(self.grad_probe_args.grad_probe_window_size)
        if window_size <= 0:
            return None
        if self._window_sum_gm is None:
            self._window_sum_gm = [torch.zeros_like(g) for g in grads_main]
            self._window_sum_gs = [torch.zeros_like(g) for g in grads_skip]

        for idx, g in enumerate(grads_main):
            self._window_sum_gm[idx].add_(g)
        for idx, g in enumerate(grads_skip):
            self._window_sum_gs[idx].add_(g)

        self._window_sum_gm_norm2 += gm_norm2
        self._window_sum_gs_norm2 += gs_norm2
        self._window_sum_gf_norm2 += gf_norm2
        self._window_sum_dot_ms += dot_ms
        self._window_count += 1

        if self._window_count < window_size:
            return None

        m_hat = [g / window_size for g in self._window_sum_gm]
        s_hat = [g / window_size for g in self._window_sum_gs]
        m_hat_norm2 = self._sum_dot(m_hat, m_hat)
        s_hat_norm2 = self._sum_dot(s_hat, s_hat)
        dot_mhat_shat = self._sum_dot(m_hat, s_hat)
        avg_gs_norm2 = self._window_sum_gs_norm2 / window_size
        avg_gm_norm2 = self._window_sum_gm_norm2 / window_size
        avg_gf_norm2 = self._window_sum_gf_norm2 / window_size
        # Trace of covariance (variance): E[||g||^2] - ||E[g]||^2
        trace_sigma_m = max(0.0, avg_gm_norm2 - m_hat_norm2)
        trace_sigma_s = max(0.0, avg_gs_norm2 - s_hat_norm2)
        # Cross-covariance trace: 2 * (E[<g_m, g_s>] - <E[g_m], E[g_s]>)
        cc_trace = (2.0 / window_size) * (self._window_sum_dot_ms - window_size * dot_mhat_shat)
        delta_proxy = abs(cc_trace) / (trace_sigma_s + 1e-12)
        penalty = avg_gf_norm2 - avg_gm_norm2

        self._window_sum_gm = None
        self._window_sum_gs = None
        self._window_sum_gm_norm2 = 0.0
        self._window_sum_gs_norm2 = 0.0
        self._window_sum_gf_norm2 = 0.0
        self._window_sum_dot_ms = 0.0
        self._window_count = 0

        return {
            "m_hat_norm2": m_hat_norm2,
            "s_hat_norm2": s_hat_norm2,
            "dot_mhat_shat": dot_mhat_shat,
            "cos_mhat_shat": dot_mhat_shat / ((m_hat_norm2 * s_hat_norm2) ** 0.5 + 1e-12),
            "avg_gs_norm2": avg_gs_norm2,
            "avg_gm_norm2": avg_gm_norm2,
            "avg_gf_norm2": avg_gf_norm2,
            "trace_sigma_m": trace_sigma_m,
            "trace_sigma_s": trace_sigma_s,
            "variance_ratio": trace_sigma_s / (trace_sigma_m + 1e-12),
            "cc_trace": cc_trace,
            "delta_proxy": delta_proxy,
            "penalty": penalty,
            # Assumption 2: mean alignment
            "mean_alignment_dot": dot_mhat_shat,
            "mean_alignment_nonpositive": bool(dot_mhat_shat <= 0.0),
            "rho_mean_ratio": (s_hat_norm2 ** 0.5) / ((m_hat_norm2 ** 0.5) + 1e-12),
        }

    def _run_grad_probe(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """Probe main vs. full gradients on one batch and record the metrics.

        Two forward/backward passes are run on the same batch with the same
        RNG state: one with `detach_skiplink_layers` set to True on the model
        ("main") and one with it set to False ("full"). Their difference is
        treated as the skip-link gradient. Squared norms and the main/skip
        inner product are summed across processes, per-step metrics are
        appended to the metrics JSONL on process zero, window statistics are
        updated there, and the one-step loss comparison is triggered when
        configured.

        Args:
            model: The (possibly wrapped) model used for the forward passes.
            inputs: The batch passed to `self.compute_loss` for both passes.

        The model's `detach_skiplink_layers` flag and its gradient-checkpointing
        settings are restored even if one of the passes raises.

        NOTE(review): `detach_skiplink_layers` is only toggled here; what the
        model does with it (presumably cutting the skip-link branch out of the
        autograd graph) lives outside this file -- confirm against the model.
        """
        if self._grad_probe_params is None:
            self._init_grad_probe_params()
            if self.is_world_process_zero():
                logger.info(f"GradProbe initialized. Output Dir: {self.grad_probe_args.grad_probe_output_dir}")
                logger.info(f"GradProbe Config: Window Size={self.grad_probe_args.grad_probe_window_size}, Log Every={self.grad_probe_args.grad_probe_log_every}")
                if self.grad_probe_args.grad_probe_window_size <= 0:
                    logger.warning("WARNING: grad_probe_window_size is 0! Window-based statistics (Trace, Variance Ratio, etc.) will NOT be computed or saved.")

        if not self._grad_probe_params:
            return

        checkpoint_states = self._toggle_checkpointing(model, enable=False)
        checkpoint_context = self._checkpoint_context()

        # Handle DDP wrapping for custom attributes
        model_container = model.module if hasattr(model, "module") else model
        original_detach = getattr(model_container, "detach_skiplink_layers", None)
        rng_state = self._get_rng_state()

        try:
            with checkpoint_context:
                # Pass 1: skip-link layers detached ("main" gradient).
                with self.compute_loss_context_manager():
                    if original_detach is not None:
                        model_container.detach_skiplink_layers = True
                    loss_main = self.compute_loss(model, inputs)

                grads_main = torch.autograd.grad(
                    loss_main,
                    self._grad_probe_params,
                    allow_unused=True,
                    retain_graph=False,
                )

                # Pass 2: same batch, same RNG state, skip-link layers attached ("full" gradient).
                self._set_rng_state(rng_state)
                with self.compute_loss_context_manager():
                    if original_detach is not None:
                        model_container.detach_skiplink_layers = False
                    loss_full = self.compute_loss(model, inputs)

                grads_full = torch.autograd.grad(
                    loss_full,
                    self._grad_probe_params,
                    allow_unused=True,
                    retain_graph=False,
                )

            grads_main, missing_main = self._replace_none_grads(grads_main, self._grad_probe_params)
            grads_full, missing_full = self._replace_none_grads(grads_full, self._grad_probe_params)
        finally:
            # Always put the model back the way it was, even when a pass failed,
            # so a probe error cannot leave training in a modified configuration.
            if original_detach is not None:
                model_container.detach_skiplink_layers = original_detach
            self._toggle_checkpointing(model, enable=True, states=checkpoint_states)

        grads_skip = [gf - gm for gf, gm in zip(grads_full, grads_main)]

        # Local contributions first, then summed over all processes. These are
        # collective calls: every process must issue them in this order.
        gm_norm2 = self._all_reduce_scalar(self._sum_dot(grads_main, grads_main))
        gs_norm2 = self._all_reduce_scalar(self._sum_dot(grads_skip, grads_skip))
        gf_norm2 = self._all_reduce_scalar(self._sum_dot(grads_full, grads_full))
        dot_ms = self._all_reduce_scalar(self._sum_dot(grads_main, grads_skip))
        cos_ms = dot_ms / ((gm_norm2 * gs_norm2) ** 0.5 + 1e-12)
        ratio_2nd = gs_norm2 / (gm_norm2 + 1e-12)

        save_dtype = self._get_save_dtype()
        grads_main = [g.detach().to(dtype=save_dtype).cpu() for g in grads_main]
        grads_full = [g.detach().to(dtype=save_dtype).cpu() for g in grads_full]
        grads_skip = [g.detach().to(dtype=save_dtype).cpu() for g in grads_skip]

        output_dir = self.grad_probe_args.grad_probe_output_dir
        if output_dir is None:
            output_dir = os.path.join(self.args.output_dir, "grad_probe")
        os.makedirs(output_dir, exist_ok=True)

        rank = 0
        world_size = 1
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

        global_step = int(self.state.global_step)
        loss_main_value = float(loss_main.detach().cpu())
        loss_full_value = float(loss_full.detach().cpu())

        if bool(self.grad_probe_args.save_grad_payload_files):
            payload = {
                "global_step": global_step,
                "rank": rank,
                "world_size": world_size,
                "loss_main": loss_main_value,
                "loss_full": loss_full_value,
                "param_names": self._grad_probe_param_names,
                "missing_main": missing_main,
                "missing_full": missing_full,
                "grads_main": grads_main,
                "grads_full": grads_full,
                "grads_skip": grads_skip,
            }
            file_path = os.path.join(output_dir, f"grad_probe_step_{self.state.global_step}_rank{rank}.pt")
            try:
                torch.save(payload, file_path)
            except Exception:
                # Saving payloads is optional diagnostics: report the failure but
                # never let it interrupt training.
                logger.warning("GradProbe: failed to save gradient payload to %s", file_path, exc_info=True)
            else:
                self._window_files.append(file_path)

        window_size = int(self.grad_probe_args.grad_probe_window_size)
        if self.is_world_process_zero():
            metrics = {
                "global_step": global_step,
                "loss_main": loss_main_value,
                "loss_full": loss_full_value,
                "gm_norm2": gm_norm2,
                "gs_norm2": gs_norm2,
                "gf_norm2": gf_norm2,
                "dot_ms": dot_ms,
                "cos_ms": cos_ms,
                "ratio_2nd": ratio_2nd,
            }
            self._append_jsonl(self._grad_probe_metrics_path, metrics)

            window_payload = self._update_window_stats(
                grads_main,
                grads_skip,
                gm_norm2,
                gs_norm2,
                gf_norm2,
                dot_ms,
            )
            if window_payload is not None:
                window_payload["global_step"] = global_step
                self._append_jsonl(self._grad_probe_window_path, window_payload)
                logger.info(f"GradProbe: Window statistics saved to {self._grad_probe_window_path} at step {self.state.global_step}")
                # After finalizing a window, cleanup stored payload files if not retained
                self._cleanup_window_files()
            elif self.grad_probe_args.grad_probe_window_size > 0:
                # Log progress for window accumulation
                logger.info(f"GradProbe: Accumulating window stats ({self._window_count}/{self.grad_probe_args.grad_probe_window_size})")
        elif window_size > 0 and len(self._window_files) >= window_size:
            # Window aggregation only runs on process zero, so the other processes
            # would otherwise never delete their payload files nor reset the tracker.
            # One file is tracked per probe, so a full tracker marks a finished window.
            self._cleanup_window_files()

        if self.grad_probe_args.one_step_enable:
            interval = max(int(self.grad_probe_args.one_step_every), 1)
            if (global_step % interval) == 0:
                self._run_one_step_compare(model, inputs, grads_main, grads_full)

    def _run_one_step_compare(self, model, inputs, grads_main, grads_full):
        """Compare the loss after one plain gradient step with main vs. full gradients.

        The probed parameters are temporarily moved by `-lr * grad` (with
        `lr` taken from the first optimizer parameter group), the loss on
        `inputs` is re-evaluated, and the parameters are put back. This is
        done once with `grads_main` and once with `grads_full`. The three
        losses (before, after-main, after-full) are averaged across
        processes and appended to the one-step JSONL file on process zero.

        Args:
            model: The (possibly wrapped) model used for the loss evaluations.
            inputs: The batch passed to `self.compute_loss`.
            grads_main: Per-parameter gradients of the skip-link-detached pass.
            grads_full: Per-parameter gradients of the full pass.

        Nothing is done when no learning rate can be read from the optimizer.
        The original parameter values and the gradient-checkpointing settings
        are restored even if a loss evaluation raises.
        """
        lr = None
        if self.optimizer is not None and len(self.optimizer.param_groups) > 0:
            lr = self.optimizer.param_groups[0].get("lr", None)
        if lr is None:
            return

        checkpoint_states = self._toggle_checkpointing(model, enable=False)
        try:
            checkpoint_context = self._checkpoint_context()
            params = self._grad_probe_params
            # The same RNG state is replayed for every loss evaluation so that
            # stochastic layers (e.g. dropout) do not blur the comparison; this
            # mirrors what the gradient probe does between its two passes.
            rng_state = self._get_rng_state()

            def _restore(original_data):
                for p, orig in zip(params, original_data):
                    p.data = orig.to(device=p.device)

            def _loss_after_step(grads, original_data):
                # Apply one trial update, evaluate the loss, then always undo the update.
                try:
                    for p, g in zip(params, grads):
                        # Ensure gradient is on the same device and type
                        step = g.to(device=p.device, dtype=p.dtype)
                        p.data = p.data - (step * lr)
                    self._set_rng_state(rng_state)
                    with torch.no_grad():
                        return self.compute_loss(model, inputs)
                finally:
                    _restore(original_data)

            def _execute_comparison():
                # Clone original data (now gathered if using ZeRO-3)
                original_data = [p.data.clone() for p in params]

                with checkpoint_context:
                    with torch.no_grad():
                        loss0 = self.compute_loss(model, inputs)

                loss_detach_after = _loss_after_step(grads_main, original_data)
                loss_full_after = _loss_after_step(grads_full, original_data)
                return loss0, loss_detach_after, loss_full_after

            # Execute with or without Zero-3 context
            if is_deepspeed_zero3_enabled():
                import deepspeed
                # Gather the partitioned parameters so they can be read and perturbed.
                # With modifier_rank=None no modification is broadcast when the context
                # exits (per the DeepSpeed docs); that is fine here because every trial
                # update is undone before leaving the context.
                with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
                    loss0, loss_detach_after, loss_full_after = _execute_comparison()
            else:
                loss0, loss_detach_after, loss_full_after = _execute_comparison()

            # One collective per loss; every process must take part.
            loss0_val = self._all_reduce_loss(loss0)
            loss_detach_val = self._all_reduce_loss(loss_detach_after)
            loss_full_val = self._all_reduce_loss(loss_full_after)

            if self.is_world_process_zero():
                payload = {
                    "global_step": int(self.state.global_step),
                    "lr": float(lr),
                    "loss0": loss0_val,
                    "loss_detach_after": loss_detach_val,
                    "loss_full_after": loss_full_val,
                    "delta_detach": loss_detach_val - loss0_val,
                    "delta_full": loss_full_val - loss0_val,
                }
                self._append_jsonl(self._one_step_path, payload)
        finally:
            self._toggle_checkpointing(model, enable=True, states=checkpoint_states)

    def _toggle_checkpointing(self, model: nn.Module, enable: bool, states: Optional[Dict[str, Optional[bool]]] = None):
        if states is None:
            states = {
                "model": getattr(model, "gradient_checkpointing", None),
                "vision": getattr(getattr(model, "vision_model", None), "gradient_checkpointing", None),
                "language": getattr(getattr(model, "language_model", None), "gradient_checkpointing", None),
                "vision_transformer": getattr(getattr(getattr(model, "vision_model", None), "transformer", None), "grad_checkpointing", None),
            }

        def _apply_attr(obj, attr, value):
            if obj is None:
                return
            if hasattr(obj, attr):
                setattr(obj, attr, value)

        def _apply_checkpointing(obj, value):
            if obj is None:
                return

            if hasattr(obj, "gradient_checkpointing_enable") and hasattr(obj, "gradient_checkpointing_disable"):
                try:
                    if value:
                        obj.gradient_checkpointing_enable()
                    else:
                        obj.gradient_checkpointing_disable()
                    return
                except ValueError:
                    # Some models (like PeViT) might claim incompatibility or lack proper support flags at runtime.
                    # If this fails, we fall through to manual handling if available, or just ignore 
                    # if the model doesn't actually implement the disable logic correctly but we handle submodules manually.
                    pass

            if hasattr(obj, "_set_gradient_checkpointing"):
                try:
                    obj._set_gradient_checkpointing(value)
                    return
                except ValueError:
                    pass

            if hasattr(obj, "gradient_checkpointing"):
                obj.gradient_checkpointing = value

        if enable:
            if states.get("model") is not None:
                _apply_checkpointing(model, states["model"])
            if states.get("vision") is not None:
                _apply_checkpointing(model.vision_model, states["vision"])
            if states.get("language") is not None:
                _apply_checkpointing(model.language_model, states["language"])
            if states.get("vision_transformer") is not None:
                if hasattr(model.vision_model, "transformer") and hasattr(model.vision_model.transformer, "set_grad_checkpointing"):
                    model.vision_model.transformer.set_grad_checkpointing(states["vision_transformer"])
                else:
                    _apply_attr(model.vision_model.transformer, "grad_checkpointing", states["vision_transformer"])
            return states

        _apply_checkpointing(model, False)
        if hasattr(model, "vision_model"):
            _apply_checkpointing(model.vision_model, False)
            if hasattr(model.vision_model, "transformer"):
                if hasattr(model.vision_model.transformer, "set_grad_checkpointing"):
                    model.vision_model.transformer.set_grad_checkpointing(False)
                else:
                    _apply_attr(model.vision_model.transformer, "grad_checkpointing", False)
        if hasattr(model, "language_model"):
            _apply_checkpointing(model.language_model, False)
        return states

    def _checkpoint_context(self):
        if hasattr(torch.utils.checkpoint, "set_checkpoint_use_reentrant"):
            return torch.utils.checkpoint.set_checkpoint_use_reentrant(False)
        return contextlib.nullcontext()

    # def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Run one forward/backward pass on a batch and return its loss.

        Args:
            model: Module to train; it is put into training mode first.
            inputs: Batch, passed through ``self._prepare_inputs`` before use.
            num_items_in_batch: Accepted for signature compatibility; this
                implementation does not read it.

        Returns:
            On the SageMaker model-parallel path, the mean of the loss from
            ``smp_forward_backward``, detached and moved to ``self.args.device``.
            Otherwise the detached loss divided by
            ``self.args.gradient_accumulation_steps``.
        """
        args = self.args
        model.train()
        inputs = self._prepare_inputs(inputs)

        # SageMaker model parallelism does forward and backward in one call.
        if is_sagemaker_mp_enabled():
            smp_loss = smp_forward_backward(model, inputs, args.gradient_accumulation_steps)
            return smp_loss.reduce_mean().detach().to(args.device)

        # Optional gradient probe runs on the batch before the real forward
        # pass; remember the step on which it last ran.
        if self._should_probe():
            self._run_grad_probe(model, inputs)
            self._grad_probe_last_step = self.state.global_step

        with self.compute_loss_context_manager():
            step_loss = self.compute_loss(model, inputs)

        # Drop the batch reference and release cached CUDA blocks before backward.
        del inputs
        torch.cuda.empty_cache()

        # With more than one GPU the loss is averaged before backward.
        step_loss = step_loss.mean() if args.n_gpu > 1 else step_loss

        if not self.use_apex:
            self.accelerator.backward(step_loss)
        else:
            with amp.scale_loss(step_loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()

        return step_loss.detach() / args.gradient_accumulation_steps