# HealthCheckCallback.py

import torch
# --- Import necessary types from the transformers library ---
from transformers import (
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl
)
from transformers.utils import logging

# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)

class HealthCheckCallback(TrainerCallback):
    """
    A callback that monitors the numerical stability of model weights during training.

    After every training step the model parameters are scanned for NaN/Inf
    values. When one is found, an error is logged and
    ``control.should_training_stop`` is set so the Trainer stops. At the end
    of training the parameters are scanned once more and a summary banner is
    printed if instability was observed by this callback.
    """

    # True once *this callback* has detected NaN/Inf during the current run.
    # Kept separately from ``control.should_training_stop`` because that flag
    # is also raised by the Trainer itself when training finishes normally,
    # so it cannot be used to tell "stopped by us" from "finished".
    # Class-level default so instances work without a custom ``__init__``.
    _instability_detected = False

    @staticmethod
    def _find_non_finite_parameter(model):
        """Return ``(name, param)`` for the first parameter containing NaN/Inf, or ``None``."""
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                return name, param
        return None

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """
        Called at the beginning of training; clears the detection flag so the
        same callback instance can be reused across several training runs.
        """
        self._instability_detected = False

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """
        Called at the end of each training step (after the gradient update).

        Scans the parameters of ``kwargs["model"]`` and, if any contains
        NaN/Inf, logs an error and requests that training stop.

        Returns:
            ``control`` when ``control.should_training_stop`` is set,
            otherwise ``None``.
        """
        if not state.is_world_process_zero:
            return  # Only perform the check on the main process.
        # NOTE(review): the stop flag below is therefore set on the main
        # process only; confirm that other ranks are stopped as well in
        # multi-process training.

        model = kwargs.get("model")
        if model is None:
            logger.warning("HealthCheckCallback: Model not found in kwargs. Cannot perform health check.")
            return

        # One offending parameter is enough; the helper stops at the first hit.
        offender = self._find_non_finite_parameter(model)
        if offender is not None:
            name, param = offender
            nan_detected = torch.isnan(param).any()
            inf_detected = torch.isinf(param).any()

            error_message = (
                f"!!! Numerical Instability Detected !!!\n"
                f"Invalid values found in parameter '{name}' at step {state.global_step}.\n"
                f"  - Contains NaN: {nan_detected.item()}\n"
                f"  - Contains Inf: {inf_detected.item()}\n"
                f"Training will be stopped immediately."
            )
            logger.error(error_message)

            # Remember that the stop was requested by this callback, then
            # set the control flag to tell the Trainer to stop training.
            self._instability_detected = True
            control.should_training_stop = True

        # Hand the control object back whenever a stop has been requested.
        if control.should_training_stop:
            return control

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", model=None, **kwargs):
        """
        Called at the end of training to run a final check and print a summary.

        The banner is printed only when this callback detected instability
        during training or when the final parameter scan fails. The state of
        ``control.should_training_stop`` is deliberately ignored here because
        it is also set when training reaches its last step normally.
        """
        final_check_failed = False
        # ``is not None`` rather than truthiness: container modules that
        # define ``__len__`` evaluate as falsy when empty.
        if model is not None:
            offender = self._find_non_finite_parameter(model)
            if offender is not None:
                name = offender[0]
                print(f"!!! Final Check Failed: Parameter '{name}' contains NaN/Inf after training finished.")
                final_check_failed = True

        if self._instability_detected or final_check_failed:
            print("\n" + "="*50)
            if self._instability_detected:
                print("Training was stopped early due to detected numerical instability (NaN/Inf).")
            else:
                # Training ran to completion; only the final scan failed.
                print("Numerical instability (NaN/Inf) was detected in the model after training finished.")
            print(f"Please check the training logs and configuration, especially around step {state.global_step}.")
            print("Common causes: high learning rate, lack of gradient clipping, mixed-precision issues.")
            print("="*50 + "\n")