import numpy as np
import torch
import torch.nn as nn

from torch.nn import functional as F
from typing import Dict, Callable, List, Optional, Tuple, Union
import sys

from offlinerlkit.policy import TD3Policy
from offlinerlkit.utils.noise import GaussianNoise
from offlinerlkit.utils.scaler import StandardScaler


import torch
import torch.nn.functional as F

@torch.no_grad()
def project_linear_spectral_norm_(module: torch.nn.Module, c: float = 1.0, n_iter: int = 1):
    """
    Project each nn.Linear weight W in-place so that ||W||_2 <= c (approximately).

    The top singular value of W is estimated with ``n_iter`` rounds of power
    iteration from a random start vector; if the estimate exceeds ``c`` the whole
    weight matrix is rescaled by ``c / sigma``. The estimate never exceeds the true
    spectral norm, so with few iterations the bound is only approximately enforced.

    Args:
        module: module whose nn.Linear submodules (including itself) are projected.
        c: target upper bound on the spectral norm.
        n_iter: number of power-iteration rounds; values < 1 are treated as 1
            (at least one round is needed to define the right singular vector).
    """
    n_iter = max(1, int(n_iter))
    for m in module.modules():
        if not isinstance(m, torch.nn.Linear):
            continue
        W = m.weight  # [out, in]
        # Start vector in W's dtype/device so the matmuls below also work for
        # non-float32 weights (e.g. float64 or half).
        u = torch.randn(W.shape[0], device=W.device, dtype=W.dtype)
        u = F.normalize(u, dim=0, eps=1e-12)
        for _ in range(n_iter):
            v = F.normalize(W.T @ u, dim=0, eps=1e-12)
            u = F.normalize(W @ v, dim=0, eps=1e-12)
        # u^T W v with unit u, v: approximate top singular value.
        sigma = (u @ (W @ v)).abs()
        if sigma > c:
            m.weight.mul_(c / (sigma + 1e-12))



# ============================================================
# Utilities: extract Adam preconditioner d, build Z, compute S
# ============================================================

def _optimizer_trainable_params_in_order(optim: torch.optim.Optimizer) -> List[torch.nn.Parameter]:
    """
    Return a list of all parameters with requires_grad=True in the order of optimizer.param_groups traversal.
    This ensures consistency with the concatenation order of d_flat (avoiding misalignment).
    """
    params: List[torch.nn.Parameter] = []
    for group in optim.param_groups:
        for p in group["params"]:
            if isinstance(p, torch.nn.Parameter) and p.requires_grad:
                params.append(p)
    return params


@torch.no_grad()
def extract_adam_d_like_params(
    optim: torch.optim.Optimizer,
    *,
    strict: bool = True,
    bias_correction: bool = True,
    clamp_bias_correction: float = 1e-16,
    use_amsgrad: Optional[bool] = None,
) -> List[torch.Tensor]:
    """
    Recover the per-element Adam preconditioner d from an Adam/AdamW optimizer's state.

        d     = 1 / (sqrt(v_hat) + eps)
        v_hat = v / max(1 - beta2^t, clamp_bias_correction)   if bias_correction
        v_hat = v                                             otherwise

    ``v`` is ``max_exp_avg_sq`` when AMSGrad is active (the group's ``amsgrad`` flag,
    unless overridden by ``use_amsgrad``), otherwise ``exp_avg_sq``.

    Only ``nn.Parameter`` entries with ``requires_grad=True`` are visited, in
    ``optim.param_groups`` order, so the result lines up with per-sample gradients
    taken with ``torch.autograd.grad`` over the same parameters.

    Args:
        optim: optimizer whose param groups carry ``betas`` (and optionally ``eps``/``amsgrad``).
        strict: if True, a parameter with empty state raises; otherwise its d is all ones.
        bias_correction: divide v by (1 - beta2^t) before taking the square root.
        clamp_bias_correction: lower bound on (1 - beta2^t) to avoid division by ~0.
        use_amsgrad: force (True/False) or infer (None) which second-moment buffer to use.

    Returns:
        One tensor per visited parameter, each shaped like that parameter.

    Raises:
        RuntimeError: missing ``betas``, empty state (when strict), or missing
            ``exp_avg_sq``/``max_exp_avg_sq``/``step`` entries.

    If the optimizer class already provides get_d_like_params/get_d_flat, this
    function may be unnecessary.
    """
    out: List[torch.Tensor] = []

    for grp in optim.param_groups:
        grp_betas = grp.get("betas", None)
        if grp_betas is None or len(grp_betas) < 2:
            raise RuntimeError("Optimizer param_group has no 'betas' (not an Adam-like optimizer?).")
        b2: float = float(grp_betas[1])
        grp_eps: float = float(grp.get("eps", 1e-8))

        if use_amsgrad is None:
            want_max: bool = bool(grp.get("amsgrad", False))
        else:
            want_max = bool(use_amsgrad)

        for prm in grp["params"]:
            if not (isinstance(prm, torch.nn.Parameter) and prm.requires_grad):
                continue

            state = optim.state.get(prm, None)
            if not state:
                # No step() has touched this parameter yet.
                if strict:
                    raise RuntimeError(
                        "Adam state is empty for some parameters. "
                        "Call backward() + optimizer.step() at least once before extracting D."
                    )
                out.append(torch.ones_like(prm, memory_format=torch.preserve_format))
                continue

            sq = state.get("max_exp_avg_sq" if want_max else "exp_avg_sq", None)
            if sq is None:
                if want_max:
                    raise RuntimeError("AMSGrad is enabled but state has no 'max_exp_avg_sq'.")
                raise RuntimeError("Adam state has no 'exp_avg_sq'.")

            raw_step = state.get("step", None)
            if raw_step is None:
                raise RuntimeError("Adam state has no 'step'.")

            # 'step' may be a Python number or a 0-dim tensor depending on the torch version.
            t = (
                raw_step.to(device=sq.device, dtype=sq.dtype)
                if torch.is_tensor(raw_step)
                else torch.tensor(float(raw_step), device=sq.device, dtype=sq.dtype)
            )

            if bias_correction:
                denom_bc = torch.clamp(1.0 - (b2 ** t), min=clamp_bias_correction)
                sq_hat = sq / denom_bc
            else:
                sq_hat = sq

            out.append(1.0 / (torch.sqrt(sq_hat) + grp_eps))

    return out

@torch.no_grad()
def get_adam_update_stats(optim: torch.optim.Optimizer) -> Dict[str, float]:
    """
    Summarize the Adam step implied by the optimizer's current state.

    For every parameter that has a gradient and a non-empty state, the element-wise step
        Delta = lr * m_hat / (sqrt(v_hat) + eps),
        m_hat = exp_avg / (1 - beta1^t),  v_hat = v / (1 - beta2^t)
    is recomputed, where v is ``max_exp_avg_sq`` for AMSGrad groups and
    ``exp_avg_sq`` otherwise. Weight decay (Adam's L2 term or AdamW's decoupled
    decay) is not included.

    Returns:
        dict with
          - "update_abs_mean" / "update_abs_max" / "update_abs_min": statistics of |Delta|
          - "update_var": unbiased variance of Delta (0.0 when fewer than two
            elements, where the unbiased estimate would be NaN)
        All values are 0.0 when no parameter qualifies.
    """
    all_updates = []
    for group in optim.param_groups:
        beta1, beta2 = group['betas']
        eps = group['eps']
        lr = group['lr']
        amsgrad = group.get('amsgrad', False)

        for p in group['params']:
            if p.grad is None:
                continue
            state = optim.state.get(p, None)
            if not state:
                continue

            m = state['exp_avg']
            v = state['max_exp_avg_sq'] if amsgrad else state['exp_avg_sq']
            step = state['step']

            # 'step' may be a Python number or a 0-dim tensor depending on the torch version.
            step_val = step.item() if torch.is_tensor(step) else step
            if step_val <= 0:
                continue

            bias_correction1 = 1 - beta1 ** step_val
            bias_correction2 = 1 - beta2 ** step_val

            m_hat = m / bias_correction1
            v_hat = v / bias_correction2

            update = lr * m_hat / (torch.sqrt(v_hat) + eps)
            all_updates.append(update.reshape(-1))

    if not all_updates:
        return {
            "update_abs_mean": 0.0, "update_abs_max": 0.0,
            "update_abs_min": 0.0, "update_var": 0.0
        }

    flat_updates = torch.cat(all_updates)
    abs_updates = flat_updates.abs()
    # Unbiased variance is undefined (NaN) for a single element; report 0.0 instead.
    update_var = float(flat_updates.var().item()) if flat_updates.numel() > 1 else 0.0

    return {
        "update_abs_mean": float(abs_updates.mean().item()),
        "update_abs_max": float(abs_updates.max().item()),
        "update_abs_min": float(abs_updates.min().item()),
        "update_var": update_var,
    }

@torch.no_grad()
def extract_adam_d_flat(
    optim: torch.optim.Optimizer,
    *,
    strict: bool = True,
    bias_correction: bool = True,
    clamp_bias_correction: float = 1e-16,
    use_amsgrad: Optional[bool] = None,
) -> torch.Tensor:
    """
    Return the Adam preconditioner as one flat vector d_flat of shape [P].

    The per-parameter tensors from ``extract_adam_d_like_params`` (same keyword
    arguments) are flattened and concatenated in optimizer traversal order. An
    optimizer with no eligible parameters yields an empty 1-D tensor.
    """
    pieces = extract_adam_d_like_params(
        optim,
        strict=strict,
        bias_correction=bias_correction,
        clamp_bias_correction=clamp_bias_correction,
        use_amsgrad=use_amsgrad,
    )
    if not pieces:
        return torch.empty(0)
    flat_pieces = [piece.reshape(-1) for piece in pieces]
    return torch.cat(flat_pieces, dim=0)


def build_Z_matrix(
    critic: nn.Module,
    params_in_order: List[torch.nn.Parameter],
    obss: torch.Tensor,
    actions: torch.Tensor,
) -> torch.Tensor:
    """
    Build the per-sample Jacobian Z of the critic output w.r.t. its parameters.

    Column i of Z is dQ(obss[i], actions[i]) / dtheta, flattened and concatenated in
    ``params_in_order`` order. That order must match the flattening order of ``d_flat``
    (here: the optimizer.param_groups traversal) so that Z^T diag(d) Z is aligned.

    Args:
        critic: called as ``critic(obss, actions)``; its output must hold exactly one
            value per sample (e.g. shape [B] or [B, 1]).
        params_in_order: parameters to differentiate against, in flattening order.
        obss, actions: batch inputs; B is taken from ``obss.shape[0]``.

    Returns:
        Z: [P, B] tensor, with P the total number of elements in ``params_in_order``.
        Parameters not reached from the output contribute zero rows. An empty batch
        (or an empty parameter list) yields an all-zero [P, B] tensor.

    Notes:
        One backward pass per sample (B passes in total); must be called with grad enabled.
    """
    B = obss.shape[0]
    P = sum(p.numel() for p in params_in_order)
    if B == 0 or not params_in_order:
        # Nothing to differentiate: torch.stack([]) / autograd.grad(..., []) would raise.
        ref = params_in_order[0] if params_in_order else obss
        return torch.zeros(P, B, device=ref.device, dtype=ref.dtype)

    # reshape (rather than view) also accepts non-contiguous critic outputs.
    q = critic(obss, actions).reshape(B)  # [B]

    Z_cols: List[torch.Tensor] = []
    for i in range(B):
        grads = torch.autograd.grad(
            q[i],
            params_in_order,
            retain_graph=True,   # the same graph is reused for every sample
            create_graph=False,
            allow_unused=True,   # unused parameters come back as None -> zero rows
        )
        g_flat = torch.cat(
            [
                torch.zeros_like(p).reshape(-1) if g is None else g.reshape(-1)
                for g, p in zip(grads, params_in_order)
            ],
            dim=0,
        )  # [P]
        Z_cols.append(g_flat)

    return torch.stack(Z_cols, dim=1)  # [P, B]


@torch.no_grad()
def compute_K_from_Z(d_flat: torch.Tensor, Z1: torch.Tensor, Z2: torch.Tensor) -> torch.Tensor:
    """
    Form the preconditioned Gram matrix K = Z1^T diag(d) Z2.

    diag(d) is never materialized: each row p of Z2 is scaled by d_flat[p]
    before the matrix product.

    Args:
        d_flat: [P] diagonal of the preconditioner.
        Z1, Z2: [P, B] matrices whose columns are per-sample gradients.

    Returns:
        [B, B] tensor with K[i, j] = sum_p Z1[p, i] * d_flat[p] * Z2[p, j].
    """
    scaled_Z2 = Z2 * d_flat.unsqueeze(1)
    return torch.matmul(Z1.transpose(0, 1), scaled_Z2)


def _max_real_and_corresponding_imag(eigvals: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given complex eigenvalues:
    Returns:
      - max_real: maximum real part
      - imag_at_max_real: imaginary part of the eigenvalue with maximum real part
      - eig_at_max_real: the complex eigenvalue
      - idx: index of argmax(real)
    """
    real_parts = eigvals.real
    idx = torch.argmax(real_parts)
    max_real = real_parts[idx]
    imag_at_max_real = eigvals.imag[idx]
    eig_at_max_real = eigvals[idx]
    return max_real, imag_at_max_real, eig_at_max_real, idx


def compute_S_for_critic(
    critic: nn.Module,
    optim: torch.optim.Optimizer,
    obss: torch.Tensor,
    actions: torch.Tensor,
    next_obss: torch.Tensor,
    next_actions: torch.Tensor,
    gamma: float,
    *,
    use_td3_target_net: bool = False,
    bias_correction: bool = True,
    strict_d: bool = True,
    clamp_bias_correction: float = 1e-16,
    use_amsgrad: Optional[bool] = None,
    tau: float = 0.005,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
]:
    """
    Build the preconditioned TD-dynamics matrix S for one critic and compute its spectrum.

    With Z1 = dQ(obss, actions)/dtheta, Z2 = dQ(next_obss, next_actions)/dtheta
    (both [P, B]) and D = diag(d_flat) the Adam preconditioner:

        K_xx = Z1^T D Z1
        S    = -K_xx                             if use_td3_target_net
        S    = tau * gamma * Z2^T D Z1 - K_xx    otherwise

    d_flat comes from ``optim.get_d_flat`` when the optimizer provides it, else from
    ``extract_adam_d_flat``. The critic is evaluated in eval mode; its previous
    training mode is restored on exit, including when an exception is raised.

    Returns (8 values):
        S: [B, B]
        eigvals_S: eig(S) (complex, may have imaginary parts)
        eigvals_sym: eigvalsh((S + S^T) / 2) (real)
        max_real: maximum real part of eigvals_S
        imag_at_max_real: imaginary part of that same eigenvalue
        eig_at_max_real: complex eigenvalue with maximum real part
        d_min, d_max: 0-dim tensors, min / max of d_flat

    Raises:
        RuntimeError: if len(d_flat) != total parameter count, or propagated from
            the d-extraction helpers (e.g. empty Adam state with strict_d=True).
    """
    was_training = critic.training
    critic.eval()
    try:
        # 1) Parameters in optimizer.param_groups order (same order as d_flat).
        params_in_order = _optimizer_trainable_params_in_order(optim)

        # 2) d_flat: prefer a patched optim.get_d_flat, else extract from Adam state.
        if hasattr(optim, "get_d_flat") and callable(getattr(optim, "get_d_flat")):
            d_flat = optim.get_d_flat(
                bias_correction=bias_correction,
                strict=strict_d,
                clamp_bias_correction=clamp_bias_correction,
                use_amsgrad=use_amsgrad,
            )
        else:
            d_flat = extract_adam_d_flat(
                optim,
                strict=strict_d,
                bias_correction=bias_correction,
                clamp_bias_correction=clamp_bias_correction,
                use_amsgrad=use_amsgrad,
            )

        P = sum(p.numel() for p in params_in_order)
        if d_flat.numel() != P:
            raise RuntimeError(
                f"d_flat length mismatch: got {d_flat.numel()} but expected {P}. "
                "This usually means the flatten order between d and Z is inconsistent."
            )

        # 3) Per-sample Jacobians (needs autograd even if the caller is in no_grad).
        with torch.enable_grad():
            Z1 = build_Z_matrix(critic, params_in_order, obss, actions)  # [P,B]
            # Z2 only enters the same-network bootstrap term; skip B backward passes otherwise.
            Z2 = None if use_td3_target_net else build_Z_matrix(
                critic, params_in_order, next_obss, next_actions
            )  # [P,B]

        # 4) S
        K_xx = compute_K_from_Z(d_flat, Z1, Z1)  # [B,B]
        if use_td3_target_net:
            # Target network is stop-grad and independent of the current critic
            # parameters, so the bootstrap (self-excitation) term vanishes.
            S = -K_xx
        else:
            # "Same-network bootstrap" control form.
            # NOTE(review): the bootstrap term is scaled by tau * gamma; a plain
            # same-network TD linearization would use gamma alone — confirm intent.
            K_x2x = compute_K_from_Z(d_flat, Z2, Z1)
            S = tau * gamma * K_x2x - K_xx

        # Tensor reductions: Python's builtin min/max would iterate over all P
        # elements one 0-dim tensor at a time.
        d_min = d_flat.min()
        d_max = d_flat.max()
        print(f"d_flat: {d_min}: {d_max}")

        # 5) Spectrum of S and of its symmetric part.
        eigvals_S = torch.linalg.eigvals(S)  # complex
        sym = 0.5 * (S + S.t())
        eigvals_sym = torch.linalg.eigvalsh(sym)  # real

        # 6) Maximum real part and the imaginary part of the same eigenvalue.
        max_real, imag_at_max_real, eig_at_max_real, _ = _max_real_and_corresponding_imag(eigvals_S)
    finally:
        if was_training:
            critic.train()

    return S, eigvals_S, eigvals_sym, max_real, imag_at_max_real, eig_at_max_real, d_min, d_max


# ============================================================
# TD3+BC Policy
# ============================================================

class TD3BCPolicy(TD3Policy):
    """
    TD3+BC <Ref: https://arxiv.org/abs/2106.06860>

    This version adds spectral analysis diagnostics. Every ``spectral_every`` calls
    to ``learn`` (starting with the first call; disabled when ``spectral_every <= 0``):
    - construct S for each critic (``compute_S_for_critic``) and record
        * the maximum eigenvalue of its symmetric part (eigvals_sym.max)
        * the maximum real part max_real and its imaginary part imag_at_max_real
    - record min / max / mean of critic1's Adam second-moment buffer
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        max_action: float = 1.0,
        exploration_noise: Callable = GaussianNoise,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        update_actor_freq: int = 2,
        alpha: float = 2.5,
        scaler: StandardScaler = None,
        spectral_use_td3_target_net: bool = False,
        # ---- spectral diagnostics options ----
        spectral_every: int = 999,
        spectral_bias_correction: bool = True,
    ) -> None:

        super().__init__(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            max_action=max_action,
            exploration_noise=exploration_noise,
            policy_noise=policy_noise,
            noise_clip=noise_clip,
            update_actor_freq=update_actor_freq,
        )

        self._alpha = alpha
        self.scaler = scaler

        # diagnostics control
        self.flag = 0  # counts learn() calls; drives the diagnostics schedule
        self.spectral_every = int(spectral_every)
        self.spectral_use_td3_target_net = bool(spectral_use_td3_target_net)
        self.spectral_bias_correction = bool(spectral_bias_correction)

        # Logged values, initialized so learn() can report them before the first diagnostics run.
        self._last_actor_loss = 0.0

        self.max_eig_sym_1 = 0.0
        self.max_eig_sym_2 = 0.0

        self.v_min = 0.0
        self.v_mean = 0.0
        self.v_max = 0.0

        self.max_real_1 = 0.0
        self.imag_at_max_real_1 = 0.0
        self.max_real_2 = 0.0
        self.imag_at_max_real_2 = 0.0

        # Keys must match get_adam_update_stats' output, which is what learn() reads.
        # learn() does not refresh these; assign get_adam_update_stats(optim) to populate.
        self.update_stats = {
            "update_abs_mean": 0.0, "update_abs_max": 0.0,
            "update_abs_min": 0.0, "update_var": 0.0,
        }

    def train(self) -> None:
        self.actor.train()
        self.critic1.train()
        self.critic2.train()

    def eval(self) -> None:
        self.actor.eval()
        self.critic1.eval()
        self.critic2.eval()

    def _sync_weight(self) -> None:
        """Polyak-average online weights into the target networks with factor tau."""
        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Return the actor's action for ``obs`` (numpy array or torch tensor) as numpy.

        Observations are normalized by ``self.scaler`` when set and always fed to the
        actor as float32 on the actor's device. Unless ``deterministic``, exploration
        noise is added and the result is clipped to [-max_action, max_action].
        """
        device = next(self.actor.parameters()).device

        if self.scaler is not None:
            # The scaler operates on numpy arrays.
            obs_np = obs.detach().cpu().numpy() if torch.is_tensor(obs) else obs
            obs_np = self.scaler.transform(obs_np)
            obs_t = torch.as_tensor(obs_np, device=device, dtype=torch.float32)
        elif torch.is_tensor(obs):
            # Cast as well, so e.g. float64 tensors match the actor's weights.
            obs_t = obs.to(device=device, dtype=torch.float32)
        else:
            obs_t = torch.as_tensor(obs, device=device, dtype=torch.float32)

        with torch.no_grad():
            action = self.actor(obs_t).cpu().numpy()

        if not deterministic:
            action = action + self.exploration_noise(action.shape)
            action = np.clip(action, -self._max_action, self._max_action)
        return action

    def _run_spectral_diagnostics(
        self, obss: torch.Tensor, actions: torch.Tensor, next_obss: torch.Tensor
    ) -> None:
        """Build S for both critics and store the eigenvalue summaries on self."""
        with torch.no_grad():
            # Target-policy action for the next state, without smoothing noise.
            next_a = self.actor_old(next_obss)

        common = dict(
            gamma=self._gamma,
            use_td3_target_net=self.spectral_use_td3_target_net,
            bias_correction=self.spectral_bias_correction,
        )
        _, _, eig1_sym, max_real1, imag1, eig_max1, _, _ = compute_S_for_critic(
            self.critic1, self.critic1_optim, obss, actions, next_obss, next_a, **common
        )
        _, _, eig2_sym, max_real2, imag2, eig_max2, _, _ = compute_S_for_critic(
            self.critic2, self.critic2_optim, obss, actions, next_obss, next_a, **common
        )

        # Maximum eigenvalue of the symmetric part
        self.max_eig_sym_1 = float(eig1_sym.max().item())
        self.max_eig_sym_2 = float(eig2_sym.max().item())

        # Maximum real part + imaginary part of the same eigenvalue
        self.max_real_1 = float(max_real1.item())
        self.imag_at_max_real_1 = float(imag1.item())
        self.max_real_2 = float(max_real2.item())
        self.imag_at_max_real_2 = float(imag2.item())

        print("critic1 eig at max real:", eig_max1)
        print("critic2 eig at max real:", eig_max2)
        print("spectral_use_td3_target_net:", self.spectral_use_td3_target_net)

    def _log_critic1_second_moment(self) -> None:
        """Record and print min / max / mean of critic1's Adam second-moment state."""
        v_list = []
        target_key = None
        for group in self.critic1_optim.param_groups:
            # AMSGrad keeps its running maximum in a separate buffer.
            target_key = 'max_exp_avg_sq' if group.get('amsgrad', False) else 'exp_avg_sq'
            for p in group['params']:
                st = self.critic1_optim.state.get(p, None)
                v = st.get(target_key, None) if st is not None else None
                if v is not None:
                    v_list.append(v.reshape(-1))

        if v_list:
            v_flat = torch.cat(v_list)
            self.v_min = v_flat.min().item()
            self.v_max = v_flat.max().item()
            self.v_mean = v_flat.mean().item()

            print(f"\n[Critic1 Optim Second Moment ({target_key})]")
            print(f"  Min: {self.v_min:.8f}")
            print(f"  Max: {self.v_max:.8f}")
            print(f"  Mean: {self.v_mean:.8f}")
        else:
            print("\n[Critic1 Optim] State is empty (step() likely not called yet).")

    def learn(self, batch: Dict) -> Dict[str, float]:
        """
        Run one TD3+BC update (both critics, delayed actor + target sync) and return logs.

        Diagnostics (spectral analysis of S and critic1's second-moment statistics)
        run together on calls where ``flag % spectral_every == 0``.
        """
        obss = batch["observations"]
        actions = batch["actions"]
        next_obss = batch["next_observations"]
        rewards = batch["rewards"]
        terminals = batch["terminals"]

        # Decide once so both diagnostics below run on the same learn() call.
        run_diagnostics = self.spectral_every > 0 and (self.flag % self.spectral_every == 0)

        # -----------------------
        # update critic1/critic2
        # -----------------------
        q1 = self.critic1(obss, actions)
        q2 = self.critic2(obss, actions)

        with torch.no_grad():
            noise = (torch.randn_like(actions) * self._policy_noise).clamp(-self._noise_clip, self._noise_clip)
            next_actions = (self.actor_old(next_obss) + noise).clamp(-self._max_action, self._max_action)
            next_q = torch.min(
                self.critic1_old(next_obss, next_actions),
                self.critic2_old(next_obss, next_actions),
            )
            target_q = rewards + self._gamma * (1 - terminals) * next_q

        critic1_loss = (q1 - target_q).pow(2).mean()
        critic2_loss = (q2 - target_q).pow(2).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        # Optional: torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), max_norm=10.0)
        self.critic1_optim.step()
        # Optional: project_linear_spectral_norm_(self.critic1, c=1.0, n_iter=1)

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        # -----------------------
        # spectral diagnostics
        # -----------------------
        if run_diagnostics:
            self._run_spectral_diagnostics(obss, actions, next_obss)

        self.flag += 1

        # -----------------------
        # update actor (delayed)
        # -----------------------
        if self._cnt % self._freq == 0:
            a = self.actor(obss)
            q = self.critic1(obss, a)
            lmbda = self._alpha / q.abs().mean().detach()
            actor_loss = -lmbda * q.mean() + (a - actions).pow(2).mean()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            self._last_actor_loss = float(actor_loss.item())
            self._sync_weight()

        self._cnt += 1

        # Critic1 optimizer second-moment statistics
        if run_diagnostics:
            self._log_critic1_second_moment()

        return {
            "loss/actor": float(self._last_actor_loss),
            "loss/critic1": float(critic1_loss.item()),
            "loss/critic2": float(critic2_loss.item()),

            "v/v_min": self.v_min,
            "v/v_max": self.v_max,
            "v/v_mean": self.v_mean,

            # Maximum eigenvalue of the symmetric part
            "eig_sym/max_eig_sym_1": float(self.max_eig_sym_1),
            "eig_sym/max_eig_sym_2": float(self.max_eig_sym_2),

            # Adam update step statistics
            "update/abs_mean": self.update_stats["update_abs_mean"],
            "update/abs_max": self.update_stats["update_abs_max"],
            "update/abs_min": self.update_stats["update_abs_min"],
            "update/var": self.update_stats["update_var"],

            # Maximum real part + corresponding imaginary part
            "eig_S/max_real_1": float(self.max_real_1),
            "eig_S/imag_at_max_real_1": float(self.imag_at_max_real_1),
            "eig_S/max_real_2": float(self.max_real_2),
            "eig_S/imag_at_max_real_2": float(self.imag_at_max_real_2),
        }
