import torch


class DiagnosticHook:
    """Accumulate paired raw / softsign step tensors and summarize them.

    Call the instance once per step with ``raw_step`` and the matching
    ``softsign_step``. Each tensor is detached and flattened. Calling
    ``compute_and_reset`` then reduces everything recorded since the last
    reset into scalar metrics and empties the buffers.

    Args:
        saturation_threshold: elements with ``|softsign_step|`` strictly above
            this value count towards ``saturation_ratio``.
        damping_tol: an element counts towards ``damping_ratio`` only when
            ``|softsign_step| < |raw_step| - damping_tol``.
        eps: added to ``|raw_step|`` in the ``preservation_ratio``
            denominator to avoid division by zero.
    """

    def __init__(self, saturation_threshold=0.55, damping_tol=1e-8, eps=1e-8):
        self.raw_steps = []
        self.softsign_steps = []
        self.saturation_threshold = saturation_threshold
        self.damping_tol = damping_tol
        self.eps = eps
        self.t = None  # last value passed to log_temperature (None if never set)

    def __call__(self, raw_step, softsign_step):
        """Record one (raw_step, softsign_step) pair.

        Raises:
            ValueError: if the two tensors hold a different number of
                elements. The element-wise metrics would otherwise be
                misaligned, or silently broadcast when one side has a
                single element.
        """
        if raw_step.numel() != softsign_step.numel():
            raise ValueError(
                f"raw_step has {raw_step.numel()} elements but softsign_step "
                f"has {softsign_step.numel()}"
            )
        self.raw_steps.append(raw_step.detach().flatten())
        self.softsign_steps.append(softsign_step.detach().flatten())

    def log_temperature(self, value):
        """Store ``value``; it is reported as ``'temperature'`` by ``compute_and_reset``."""
        self.t = value

    def compute_and_reset(self):
        """Reduce all recorded steps to scalar metrics and clear the buffers.

        Returns:
            ``{}`` if nothing (or only zero-element tensors) was recorded.
            Otherwise a dict with:

            - ``'saturation_ratio'``: fraction of elements with
              ``|softsign| > saturation_threshold``.
            - ``'damping_ratio'``: fraction of elements with
              ``|softsign| < |raw| - damping_tol``.
            - ``'preservation_ratio'``: mean of ``|softsign| / (|raw| + eps)``.
            - ``'temperature'``: the last logged temperature. This value is
              not reset.
        """
        if not self.raw_steps:
            return {}

        raw = torch.cat(self.raw_steps)
        soft = torch.cat(self.softsign_steps)
        self.raw_steps.clear()
        self.softsign_steps.clear()

        if raw.numel() == 0:
            # Only empty tensors were recorded; every mean would be NaN.
            return {}

        # Compute in at least float32. In float16 the default 1e-8
        # eps/tolerance round to 0, which makes 0/0 = NaN for zero raw
        # steps and disables the damping tolerance. float64 stays float64.
        dtype = torch.promote_types(
            torch.promote_types(raw.dtype, soft.dtype), torch.float32
        )
        raw_abs = raw.to(dtype).abs()
        soft_abs = soft.to(dtype).abs()

        return {
            'saturation_ratio': (soft_abs > self.saturation_threshold).float().mean().item(),
            'damping_ratio': (soft_abs < raw_abs - self.damping_tol).float().mean().item(),
            'preservation_ratio': (soft_abs / (raw_abs + self.eps)).mean().item(),
            'temperature': self.t,
        }