"""Code for a privacy engine that enables deep learning with differential privacy

Design mostly based on Opacus and Private-transformers, and should work with 
most libraries such as huggingface, timm, torchvision, etc.

MODIFICATIONS FOR FILTER-AWARE OPTIMIZERS:
- Auto-detects optimizers with filter-aware DP correction (AdamFusion, KFOptimizer, etc.)
- Automatically updates dp_noise_std after each clipping operation
- Supports both AdamFusion (residual filter) and KFOptimizer (state filter)
- Works with both fixed and automatic clipping
- Compatible with optimizer wrappers
"""

import logging
import math
import types
from typing import Dict, Optional, Sequence, Union

import torch
from torch import nn

from . import autograd_grad_sample, transformers_support
from .accounting import accounting_manager
from torch.functional import F
import transformers
from .supported_layers_grad_samplers import _supported_layers_norm_sample_AND_clipping


class PrivacyEngine(object):
    """Differentially-private optimization engine that works in Pytorch.

    Supports book-keeping (BK) algorithm -- base and hybrid variants, as described in arXiv:2210.00038
    Supports DP-BiTFiT (bias-term only fine-tuning, which does not use BK), as described in arXiv:2210.00036

    ENHANCED: Automatically configures filter-aware optimizers with DP-AdamBC correction.
    Supports:
    - AdamFusion (residual/innovation filter): A(ω) = (2-ω)/(4-3ω)
    - KFOptimizer (state EMA filter): A(κ) = κ/(2-κ)

    Usage outline: constructing the engine registers per-sample clipping hooks on
    ``module``; ``attach(optimizer)`` then overrides ``optimizer.step`` so that each
    step first installs every parameter's ``private_grad`` as ``.grad`` and then runs
    the optimizer's original step.
    """

    # Clipping styles accepted by name; any other value must be a list/tuple of
    # block-head layer names (block-wise clipping).
    _NAMED_CLIPPING_STYLES = ('all-layer', 'layer-wise', 'param-wise')
    # How many `.optimizer` wrapper layers `attach` follows when searching for a
    # filter-aware optimizer (one exposing `set_dp_noise_std`).
    _MAX_WRAPPER_DEPTH = 10

    def __init__(
        self,
        module: nn.Module,
        *,
        batch_size: int,
        sample_size: int,
        max_grad_norm: float = 1.,
        epochs: Optional[Union[int, float]] = None,
        num_steps: Optional[Union[int, float]] = None,
        noise_multiplier: Optional[float] = None,
        target_epsilon: Optional[float] = None,
        target_delta: Optional[float] = None,
        alphas: Sequence[float] = accounting_manager.DEFAULT_ALPHAS,
        record_snr: bool = False,
        named_params: Optional[Sequence] = None,
        numerical_stability_constant=None,
        accounting_mode="rdp",
        eps_error=0.05,
        clipping_mode='MixOpt',
        clipping_fn='automatic',
        loss_reduction='mean',
        origin_params=None,
        clipping_style='all-layer',
        num_GPUs=1,
        torch_seed_is_fixed=False,
        **unused_kwargs,
    ):
        """Initialize the engine.

        Args:
            module: The PyTorch module for which per-sample gradient is required.
                Setting the `requires_grad` attribute of a parameter to False
                disables the per-sample gradient accumulation.
            batch_size: The expected size of a logical batch.
            sample_size: Size of dataset.
            max_grad_norm: The maximum 2-norm for gradient clipping. Reset to 1 when
                `clipping_fn='automatic'` and no `numerical_stability_constant` is given.
            epochs: The number of epochs for training.
            num_steps: The number of steps for training, only used if epochs is None
                (converted as epochs = num_steps / sample_size * batch_size).
            noise_multiplier: The extra multiplier for DP-SGD noise.
            target_epsilon: The target privacy spending.
                Only used to estimate the `noise_multiplier` if it is not set.
            target_delta: The target failure probability.
                Defaults to 1 / (2 * sample_size) if not set.
            alphas: The RDP orders for (ε, δ)-DP conversion. Useless if not accounting in RDP.
            record_snr: Record and report the signal-to-noise ratio --
                ratio between norm of summed clipped gradient and norm of noise vector.
            named_params: Specifies which parameters need gradients;
                defaults to use parameters which require grad in module.
            numerical_stability_constant: Small constant to avoid division by 0 when clipping.
                Defaults to 1e-2 for automatic clipping and 1e-6 otherwise.
            accounting_mode: The method of accounting privacy. One of (`rdp`, `glw`, `all`).
                Meanings of shorthands:
                    - rdp: Account loss with RDP but perform conversion to approx-DP with a procedure defined in
                        "The Discrete Gaussian for Differential Privacy". https://arxiv.org/abs/2004.00010
                    - glw: Account loss by numerically composing tradeoff functions in f-DP; defined in
                        "Numerical composition of differential privacy". https://arxiv.org/abs/2106.02848
                    - all: Report loss with all methods listed above.
            eps_error: Error threshold for upper and lower bound in the GLW accounting procedure.
            clipping_mode: The clipping mode to use. One of 'ghost' (BK), 'MixGhostClip', 'MixOpt'.
            clipping_fn: Per-sample gradient clipping function to use. One of 'Abadi','automatic','global'
            loss_reduction: Reduction of loss, one of 'sum' and 'mean'.
            origin_params: Specifies which are origin parameters as described in ghost differentiation. Can be None or list of parameter names
                ['_embeddings','wte','wpe'] is used for roberta and GPT2. For general model, can set to first layer's bias or weight.
                Ignored when only bias terms are trainable (DP-BiTFiT).
            clipping_style: The clipping style to use. One of 'all-layer', 'layer-wise', 'param-wise' or an un-ordered
                list (or tuple) of layer names that represent blocks' head layer.
            num_GPUs: Number of workers. The per-parameter noise scale stored on each
                parameter is divided by `num_GPUs` if `torch_seed_is_fixed` else by `sqrt(num_GPUs)`.
            torch_seed_is_fixed: Selects the noise divisor described under `num_GPUs`.

        Raises:
            ValueError: On an unknown `clipping_mode`, `accounting_mode` or named
                `clipping_style`; when neither `epochs` nor `num_steps` is given; on
                non-positive `epochs` without a `noise_multiplier`; or when
                `noise_multiplier` must be derived but `target_epsilon` is missing.
        """
        del unused_kwargs
        super().__init__()

        # ----- argument validation
        if clipping_mode not in ['ghost', 'MixGhostClip', 'MixOpt']:
            raise ValueError(f"Unknown clipping mode {clipping_mode}. Expected one of 'ghost','MixGhostClip','MixOpt'.")
        if accounting_mode not in ("rdp", "all", 'glw'):
            raise ValueError(f"Unknown accounting mode: {accounting_mode}. Expected one of 'rdp', 'all','glw'.")
        if isinstance(clipping_style, tuple):
            clipping_style = list(clipping_style)  # a tuple of block heads means block-wise clipping too
        if not isinstance(clipping_style, list) and clipping_style not in self._NAMED_CLIPPING_STYLES:
            # Previously an unknown string silently used len(string) as the block count.
            raise ValueError(
                f"Unknown clipping style {clipping_style}. Expected one of 'all-layer', 'layer-wise', "
                f"'param-wise' or a list of block-head layer names."
            )
        if epochs is None:
            if num_steps is not None:
                epochs = num_steps / sample_size * batch_size
            else:
                raise ValueError("Number of training epochs and training steps are not defined.")
        if epochs <= 0.0 and noise_multiplier is None:
            raise ValueError(f"Number of training epochs cannot be non-positive, but found epochs={epochs}")

        # ----- privacy parameters
        sample_rate = batch_size / sample_size
        if target_delta is None:
            target_delta = 1 / (2 * sample_size)
        if noise_multiplier is None:
            if target_epsilon is None or epochs is None:
                raise ValueError(
                    f"`target_epsilon` and `epochs` must be specified when `noise_multiplier` is `None`."
                )
            if accounting_mode in ("rdp", "all"):
                manager = accounting_manager.RDPManager(alphas=alphas)
            else:  # "glw"
                manager = accounting_manager.GLWManager(eps_error=eps_error)
            noise_multiplier = manager.compute_sigma(
                target_epsilon=target_epsilon, target_delta=target_delta, sample_rate=sample_rate, epochs=epochs,
            )

        self.batch_size = batch_size
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.max_grad_norm = max_grad_norm

        self.epochs = epochs
        self.noise_multiplier = noise_multiplier
        self.effective_noise_multiplier = noise_multiplier / batch_size
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.alphas = alphas
        self.eps_error = eps_error
        self.accounting_mode = accounting_mode
        self.record_snr = record_snr

        # Internals.
        self.steps = 0  # Tracks privacy spending.

        # Recording.
        self.max_clip = None
        self.min_clip = None
        self.med_clip = None
        self.signal = None
        self.noise = None
        self.snr = None
        self.noise_limit = None

        # Record parameters.
        self.module = module
        if named_params is None:
            self.named_params = [
                (name, param) for (name, param) in module.named_parameters() if param.requires_grad
            ]
        else:
            self.named_params = named_params
        self.num_params = sum(param.numel() for _, param in self.named_params)

        self._locked = False  # lock the part where noisy gradients is created (in `self.step`) if True.

        # Filter-aware optimizer support (AdamFusion, KFOptimizer, ...); configured in attach().
        self._adamfusion_optimizer = None
        self._adamfusion_last_clip_norm = None

        # ----- determine whether training with BiTFiT (bias-term only fine-tuning).
        # This must run BEFORE the origin-parameter trick below rewrites `requires_grad`:
        # otherwise, with a bias as origin parameter, every weight would look frozen
        # and the model would be misclassified as bias-only.
        self.bias_only = not any(
            '.bias' not in name and param.requires_grad for name, param in self.named_params
        )
        if self.bias_only:
            origin_params = None  # do not use origin parameters for BiTFiT

        # ----- ghost differentiation trick through origin parameter
        for name, param in module.named_parameters():
            param.initially_requires_grad = bool(param.requires_grad)
            if origin_params is not None:
                # only requires grad if it is origin and initially requires grad
                param.requires_grad = param.initially_requires_grad and any(i in name for i in origin_params)

        if origin_params is not None:
            print('Using origin parameters for the ghost differentiation trick......')

        # ----- supported, trainable layers (used for layer/block-wise clipping)
        def _supported_and_trainable(layer):
            """True if `layer`'s type has a registered grad sampler and its weight or bias was trainable."""
            if type(layer) not in _supported_layers_norm_sample_AND_clipping:
                return False
            return any(
                getattr(getattr(layer, attr, None), 'initially_requires_grad', False)
                for attr in ('weight', 'bias')
            )

        self.named_layers = [(name, layer) for name, layer in module.named_modules() if _supported_and_trainable(layer)]
        self.n_layers = len(self.named_layers)

        self.n_components = sum(
            sum(1 for p in layer.parameters() if p.initially_requires_grad)
            for _, layer in self.named_layers
        )
        print("Number of trainable components: ", self.n_components, "; Number of trainable layers: ", self.n_layers)

        # ----- clipping configuration
        print('>>>>>>>>>>>>>>>>> Applying ', clipping_fn, ' per-sample gradient clipping.')
        self.clipping_fn = clipping_fn
        if numerical_stability_constant is not None:
            self.numerical_stability_constant = numerical_stability_constant
        elif self.clipping_fn == 'automatic':
            self.max_grad_norm = 1.  # max_grad_norm does not matter in automatic clipping; this is necessary for step()
            self.numerical_stability_constant = 1e-2
        else:
            self.numerical_stability_constant = 1e-6

        # Split the clipping budget evenly (in L2) across the blocks that are clipped separately.
        if clipping_style == 'layer-wise':
            self.max_grad_norm_layerwise = self.max_grad_norm / math.sqrt(self.n_layers)
        elif clipping_style == 'param-wise':
            self.max_grad_norm_layerwise = self.max_grad_norm / math.sqrt(self.n_components)
        elif clipping_style == 'all-layer':
            self.max_grad_norm_layerwise = self.max_grad_norm
        else:  # list of block heads
            self.max_grad_norm_layerwise = self.max_grad_norm / math.sqrt(len(clipping_style))

        # Per-parameter attributes read elsewhere (presumably by the clipping hooks in
        # autograd_grad_sample / the layer grad samplers -- confirm there).
        noise_divisor = num_GPUs if torch_seed_is_fixed else math.sqrt(num_GPUs)
        for param in module.parameters():
            param.batch_size = self.batch_size
            param.noise = self.noise_multiplier * self.max_grad_norm / noise_divisor

        self.loss_reduction = loss_reduction
        self.clipping_mode = clipping_mode

        # ----- create list of block head layers
        if isinstance(clipping_style, list):
            self.clipping_style = 'block-wise'
            self.block_heads = clipping_style
        else:
            self.clipping_style = clipping_style
            if self.clipping_style == 'all-layer':
                self.block_heads = [self.named_layers[0][0]]
            else:  # 'layer-wise' or 'param-wise' (validated above)
                self.block_heads = [name for (name, _) in self.named_layers]
        print(">>>>>>>>>>>>>>>>> Block heads for per-sample gradient clipping are defined as:", self.block_heads)

        transformers_support.forward_swapper(module=module)  # fix the position embeddings broadcast issue.

        autograd_grad_sample.add_hooks(model=self.module, loss_reduction=self.loss_reduction,
                                       clipping_mode=self.clipping_mode, bias_only=self.bias_only,
                                       clipping_style=self.clipping_style, block_heads=self.block_heads,
                                       named_params=self.named_params, named_layers=self.named_layers,
                                       clipping_fn=self.clipping_fn,
                                       numerical_stability_constant=self.numerical_stability_constant,
                                       max_grad_norm_layerwise=self.max_grad_norm_layerwise)

        def get_privacy_spent(_self, **kwargs):
            return _self.privacy_engine.get_privacy_spent(**kwargs)

        def get_training_stats(_self, **kwargs):
            return _self.privacy_engine.get_training_stats(**kwargs)

        # Make getting info easier.
        self.module.get_privacy_spent = types.MethodType(get_privacy_spent, self.module)
        self.module.get_training_stats = types.MethodType(get_training_stats, self.module)

        self.module.privacy_engine = self

        self._patch_deepspeed_zero_stage_1()

    @staticmethod
    def _patch_deepspeed_zero_stage_1():
        """Best-effort patch of DeepSpeed ZeRO stage 1/2 gradient reduction.

        Replaces ``DeepSpeedZeroOptimizer.reduce_gradients`` with a version that, when
        ``overlap_comm`` is off, substitutes each parameter's ``private_grad`` (divided by
        ``param.batch_size`` and multiplied by ``loss_scale``) for its ``.grad`` and zeroes
        the grads of parameters without one. Does nothing if DeepSpeed cannot be imported.
        """
        try:
            from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
            from deepspeed import comm as dist
        except Exception:  # DeepSpeed is optional; any import failure simply means "no patch".
            return

        def reduce_gradients_DP_stage_1(self, pipeline_parallel=False):
            # NOTE(review): presumably adapted from DeepSpeed's own `reduce_gradients`;
            # world_size / my_rank are computed but unused here.
            world_size = dist.get_world_size(self.dp_process_group)
            my_rank = dist.get_rank(self.dp_process_group)

            # with PP we must create ipg buffer, since backward is handled outside zero
            if pipeline_parallel and self.contiguous_gradients:
                self.ipg_buffer = []
                buf_0 = torch.empty(int(self.reduce_bucket_size),
                                    dtype=self.dtype,
                                    device=torch.cuda.current_device())
                self.ipg_buffer.append(buf_0)
                self.ipg_index = 0

            if not self.overlap_comm:
                for i, group in enumerate(self.bit16_groups):
                    for param in group:
                        if param.grad is not None:
                            if hasattr(param, 'private_grad'):
                                param.grad = torch.nan_to_num(param.private_grad).contiguous()
                                del param.private_grad  # release memory
                                param.grad = param.grad / param.batch_size * self.loss_scale  # it works
                            else:
                                param.grad.zero_()

                            self.reduce_ready_partitions_and_remove_grads(param, i)
            # reduce any pending grads in either hook/non-hook case
            self.overlapping_partition_gradients_reduce_epilogue()

        DeepSpeedZeroOptimizer.reduce_gradients = reduce_gradients_DP_stage_1

    def lock(self):
        """Run this after noisy clipped gradient is created to prevent tampering with it before parameter update.

        NOTE(review): `_locked` is set here but not read anywhere in this class;
        confirm where (if anywhere) the lock is enforced.
        """
        self._locked = True

    def unlock(self):
        """Run this after parameter update to allow creation of noisy gradient for next step"""
        self._locked = False

    def attach(self, optimizer):
        """Attach privacy engine to optimizer.

        Overrides ``optimizer.step`` (the original is kept as ``optimizer.original_step``)
        and adds ``get_privacy_spent`` / ``get_training_stats`` helpers to the optimizer.

        ENHANCED: Auto-detects a filter-aware optimizer (anything exposing
        ``set_dp_noise_std``) either on ``optimizer`` itself or on any level of its
        ``.optimizer`` wrapper chain, and configures filter-aware DP-AdamBC on it.
        """
        def dp_step(_self, closure=None, **kwargs):
            """Privatized replacement for ``optimizer.step``; returns the original step's result."""
            _self.zero_grad()  # make sure no non-private grad remains
            _self.privacy_engine._create_noisy_clipped_gradient(**kwargs)
            loss = _self.original_step(closure=closure)
            _self.privacy_engine.unlock()  # Only enable creating new grads once parameters are updated.
            _self.privacy_engine.steps += 1
            return loss

        optimizer.privacy_engine = self
        optimizer.original_step = optimizer.step
        optimizer.step = types.MethodType(dp_step, optimizer)

        def get_privacy_spent(_self, **kwargs):
            return _self.privacy_engine.get_privacy_spent(**kwargs)

        def get_training_stats(_self, **kwargs):
            return _self.privacy_engine.get_training_stats(**kwargs)

        # Make getting info easier.
        optimizer.get_privacy_spent = types.MethodType(get_privacy_spent, optimizer)
        optimizer.get_training_stats = types.MethodType(get_training_stats, optimizer)

        self.optimizer = optimizer

        filter_optimizer, wrapper_depth = self._find_filter_aware_optimizer(optimizer)
        self._adamfusion_optimizer = filter_optimizer
        if filter_optimizer is not None:
            self._configure_filter_aware_optimizer(wrapper_depth)

    @classmethod
    def _find_filter_aware_optimizer(cls, optimizer):
        """Return ``(opt, depth)`` for the first optimizer exposing ``set_dp_noise_std``.

        Starts at ``optimizer`` itself (depth 0) and follows ``.optimizer`` links at most
        ``_MAX_WRAPPER_DEPTH`` times, checking every level on the way, so a filter-aware
        optimizer sitting between two wrappers is found as well. Returns
        ``(None, depth)`` when none is found.
        """
        candidate, depth = optimizer, 0
        while True:
            if hasattr(candidate, 'set_dp_noise_std'):
                return candidate, depth
            if depth >= cls._MAX_WRAPPER_DEPTH or not hasattr(candidate, 'optimizer'):
                return None, depth
            candidate = candidate.optimizer
            depth += 1

    def _configure_filter_aware_optimizer(self, wrapper_depth):
        """Push the initial ``dp_noise_std = max_grad_norm * noise_multiplier`` and print a summary.

        Filter parameters are read from the first param group: ``omega`` (with
        ``use_filter_aware_adambc``) for the residual filter, or ``kappa`` (with
        ``use_filter_aware_correction``) for the state-EMA filter when no ``omega`` is present.
        """
        filter_optimizer = self._adamfusion_optimizer

        initial_dp_noise_std = self.max_grad_norm * self.noise_multiplier
        self._adamfusion_last_clip_norm = self.max_grad_norm
        filter_optimizer.set_dp_noise_std(initial_dp_noise_std)

        optimizer_name = type(filter_optimizer).__name__
        omega = kappa = A_factor = bias_value = None
        use_bc = False
        filter_type = "Unknown"

        param_groups = getattr(filter_optimizer, 'param_groups', None)
        if param_groups:
            group = param_groups[0]

            # AdamFusion (residual/innovation filter)
            omega = group.get('omega', None)
            if omega is not None:
                filter_type = "Residual/Innovation (FIBER)"
                use_bc = group.get('use_filter_aware_adambc', False)
                if use_bc:
                    A_factor = (2.0 - omega) / (4.0 - 3.0 * omega)
                    bias_value = (initial_dp_noise_std ** 2) * A_factor

            # KFOptimizer / state filter
            kappa = group.get('kappa', None)
            if kappa is not None and omega is None:
                filter_type = "State EMA (DiSK/KF-style)"
                use_bc = group.get('use_filter_aware_correction', False)
                if use_bc:
                    A_factor = kappa / (2.0 - kappa)
                    bias_value = (initial_dp_noise_std ** 2) * A_factor

        separator = f"  {'-' * 73}"  # 75 columns wide, matching the '=' rules
        print(f"\n{'='*75}")
        print(f"{'PrivacyEngine: Filter-Aware Optimizer Integration Activated':^75}")
        print(f"{'='*75}")
        print(f"  Optimizer type:           {optimizer_name}")
        print(f"  Filter type:              {filter_type}")
        if wrapper_depth > 0:
            print(f"  Wrapper depth:            {wrapper_depth}")
        print(f"  Clipping function:        {self.clipping_fn}")
        print(f"  Clipping style:           {self.clipping_style}")
        print(f"  Initial clip norm (C):    {self.max_grad_norm:.6f}")
        print(f"  Noise multiplier (σ):     {self.noise_multiplier:.6f}")
        print(f"  Initial dp_noise_std:     {initial_dp_noise_std:.6f}")
        print(separator)

        if use_bc:
            print("  Filter-Aware Correction:  ENABLED")
            if omega is not None and A_factor is not None:
                print(f"  Omega (ω):                {omega:.6f}")
                print(f"  A(ω) = (2-ω)/(4-3ω):      {A_factor:.6f}")
                print(f"  Theoretical bias:         {bias_value:.6f}")
            elif kappa is not None and A_factor is not None:
                print(f"  Kappa (κ):                {kappa:.6f}")
                print(f"  A(κ) = κ/(2-κ):           {A_factor:.6f}")
                print(f"  Theoretical bias:         {bias_value:.6f}")
            else:
                print("  Filter parameters:        Not found in param_groups")
        else:
            print("  Filter-Aware Correction:  DISABLED")

        print(separator)
        # NOTE(review): max_grad_norm is not modified anywhere in this class after
        # __init__, so "ACTIVE" only matters if something else changes it.
        if self.clipping_fn == 'automatic':
            print("  Auto-update mode:         ACTIVE")
            print("  └─ dp_noise_std will be updated every step based on adaptive clipping")
        else:
            print("  Auto-update mode:         STATIC")
            print("  └─ dp_noise_std remains constant (fixed clipping norm)")
        print(f"{'='*75}\n")

    def detach(self):
        """Undo `attach` and the module hooks; drop cached activations/backprops and leftover private grads."""
        optimizer = self.optimizer
        optimizer.step = optimizer.original_step
        delattr(optimizer, "privacy_engine")
        delattr(optimizer, "original_step")
        delattr(optimizer, "get_privacy_spent")
        delattr(optimizer, "get_training_stats")

        # The filter-aware optimizer belonged to the detached optimizer.
        self._adamfusion_optimizer = None
        self._adamfusion_last_clip_norm = None

        module = self.module
        autograd_grad_sample.remove_hooks(module)
        module.zero_grad()

        for layer in module.modules():
            if hasattr(layer, 'activations'):
                del layer.activations
            if hasattr(layer, 'backprops'):
                del layer.backprops
        # module.parameters() already recurses, so one pass covers every parameter.
        for param in module.parameters():
            if hasattr(param, 'private_grad'):
                del param.private_grad

    def _create_noisy_clipped_gradient(self):
        """Create noisy clipped gradient for `optimizer.step`.

        Installs each tracked parameter's ``private_grad`` as its ``.grad`` (divided by
        ``batch_size`` when ``loss_reduction == 'mean'``) and then locks the engine.
        ``private_grad`` is populated outside this class (presumably by the hooks
        registered via ``autograd_grad_sample.add_hooks`` -- confirm); this method adds
        no noise itself. Parameters without ``private_grad`` are dropped from
        ``named_params`` (in place when it is a list, since the hooks received the same
        list object).

        ENHANCED: Pushes the current dp_noise_std to the filter-aware optimizer, if any.
        """
        kept, unsupported = [], []
        for pair in self.named_params:
            (kept if hasattr(pair[1], 'private_grad') else unsupported).append(pair)
        if unsupported:
            if isinstance(self.named_params, list):
                self.named_params[:] = kept
            else:
                self.named_params = kept
            unsupported_param_name = [name for name, _ in unsupported]
            print(unsupported_param_name, 'are not supported by privacy engine; these parameters are not requiring gradient nor updated.')

        signals, noises = [], []

        for name, param in self.named_params:
            param.grad = param.private_grad  # Ultra important to override `.grad`.
            del param.private_grad

            if self.record_snr:
                signals.append(param.grad.reshape(-1).norm(2))
            if self.loss_reduction == 'mean':
                param.grad /= self.batch_size

        # NOTE(review): `noises` is never filled in this method, so this branch never runs:
        # `snr` is always reported as inf and `signal`/`noise`/`noise_limit` stay unset.
        if self.record_snr and len(noises) > 0:
            self.signal, self.noise = tuple(torch.stack(lst).norm(2).item() for lst in (signals, noises))
            self.noise_limit = math.sqrt(self.num_params) * self.noise_multiplier * self.max_grad_norm
            self.snr = self.signal / self.noise
        else:
            self.snr = math.inf  # Undefined!

        # ----- filter-aware optimizer: refresh dp_noise_std after clipping
        if self._adamfusion_optimizer is not None:
            # NOTE(review): with loss_reduction='mean' the gradient is divided by batch_size
            # above, but this value is not; confirm which scale set_dp_noise_std expects.
            current_dp_noise_std = self.max_grad_norm * self.noise_multiplier
            self._adamfusion_optimizer.set_dp_noise_std(current_dp_noise_std)

            # Log clipping-norm changes, at most once every 100 steps to avoid spam.
            if (self._adamfusion_last_clip_norm is not None and
                    abs(self._adamfusion_last_clip_norm - self.max_grad_norm) > 1e-6 and
                    self.steps % 100 == 0):
                print(f"[Step {self.steps}] Adaptive clipping: "
                      f"C={self.max_grad_norm:.6f}, dp_noise_std={current_dp_noise_std:.6f}")

            self._adamfusion_last_clip_norm = self.max_grad_norm

        self.lock()  # Make creating new gradients impossible, unless optimizer.step is called.

    def get_privacy_spent(
        self,
        steps: Optional[int] = None,
        accounting_mode: Optional[str] = None,
        lenient=False
    ) -> Dict:
        """Return the privacy spent after `steps` steps (default: steps taken so far).

        Results of every requested accountant ('rdp', 'glw' or 'all') are merged into
        one dict. An accountant failure is logged and re-raised unless `lenient` is True.
        """
        if steps is None:
            steps = self.steps
        if accounting_mode is None:
            accounting_mode = self.accounting_mode

        privacy_results = {}  # Contains stats from all modes.
        if accounting_mode in ('all', 'rdp'):
            try:
                manager = accounting_manager.RDPManager(alphas=self.alphas)
                privacy_results.update(
                    manager.compute_epsilon(
                        sigma=self.noise_multiplier,
                        sample_rate=self.sample_rate,
                        target_delta=self.target_delta,
                        steps=steps,
                    )
                )
            except Exception as err:
                logging.fatal("RDP accounting failed! Double check privacy parameters.")
                if not lenient:
                    raise err

        if accounting_mode in ('all', 'glw'):
            try:
                manager = accounting_manager.GLWManager(eps_error=self.eps_error)
                privacy_results.update(
                    manager.compute_epsilon(
                        sigma=self.noise_multiplier,
                        sample_rate=self.sample_rate,
                        target_delta=self.target_delta,
                        steps=steps
                    )
                )
            except Exception as err:
                logging.fatal(
                    "Numerical composition of tradeoff functions failed! Double check privacy parameters."
                )
                if not lenient:
                    raise err

        return privacy_results

    def get_training_stats(self):
        """Get the clipping, signal, and noise statistics (plus filter-aware optimizer info)."""
        stats = {
            "med_clip": self.med_clip,
            "max_clip": self.max_clip,
            "min_clip": self.min_clip,
            "snr": self.snr,
            "signal": self.signal,
            "noise": self.noise,
            "noise_limit": self.noise_limit,
        }

        if self._adamfusion_optimizer is not None:
            stats["adamfusion_enabled"] = True
            stats["current_dp_noise_std"] = self.max_grad_norm * self.noise_multiplier
            stats["current_clip_norm"] = self.max_grad_norm
        else:
            stats["adamfusion_enabled"] = False

        return stats

    def __repr__(self):
        def _fmt(value):
            # target_epsilon is None whenever noise_multiplier was supplied directly;
            # formatting None with ':.6f' would raise TypeError.
            return "None" if value is None else f"{value:.6f}"

        adamfusion_str = "  adamfusion_integration=ACTIVE, \n" if self._adamfusion_optimizer is not None else ""

        return (
            f"PrivacyEngine(\n"
            f"  target_epsilon={_fmt(self.target_epsilon)}, \n"
            f"  target_delta={_fmt(self.target_delta)}, \n"
            f"  noise_multiplier={_fmt(self.noise_multiplier)}, \n"
            f"  effective_noise_multiplier={_fmt(self.effective_noise_multiplier)}, \n"
            f"  epochs={self.epochs}, \n"
            f"  max_grad_norm={self.max_grad_norm}, \n"
            f"  sample_rate={self.sample_rate}, \n"
            f"  batch_size={self.batch_size}, \n"
            f"  accounting_mode={self.accounting_mode}, \n"
            f"  clipping_mode={self.clipping_mode}, \n"
            f"{adamfusion_str}"
            f")"
        )