import math
from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer, required


class Fiber(Optimizer):
    """
    AdamW + residual (innovation) temporal denoising + optional filter-aware DP-AdamBC correction.

    Per-parameter gradient filter (``omega`` in (0, 1]; ``omega = 1`` disables smoothing):

        r_t       = (1 - omega) * r_{t-1} + omega * (g_t - g_tilde_{t-1})
        g_tilde_t = g_tilde_{t-1} + r_t

    ``g_tilde_t`` drives a decoupled-weight-decay Adam(W) update. When a DP noise std is
    configured (and ``use_filter_aware_adambc`` is set), the stationary variance of the
    filtered noise, ``sigma^2 * (2 - omega) / (4 - 3 * omega)``, is subtracted from the
    bias-corrected second moment and the result is floored at ``v_min``.

    IMPORTANT (two-point / DiSK style):
    - Call `prestep(closure)` only when you REALLY want to use the two-point rule for that update.
    - `step()` will only apply the z-normalization (and z^2 scaling of the subtracted noise variance)
      if `prestep()` actually ran and had a valid direction (`kf_d_t`) available.
    - `kf_d_t` stores the last displacement x_{t+1} - x_t; it is zero for parameters that
      received no gradient (and therefore did not move) in that step.

    DP note:
    - `step()` does NOT call `closure()`. It only consumes already-prepared gradients (p.private_grad / p.grad).
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        amsgrad: bool = False,
        maximize: bool = False,
        kappa: float = 0.6,
        gamma: float = 0.7,
        omega: float = 0.9,
        dp_noise_std: Optional[float] = None,   # std of DP noise BEFORE averaging (C×σ)
        batch_size: Optional[int] = None,       # logical batch size used by the DP engine for averaging
        use_filter_aware_adambc: bool = True,
        v_min: float = 1e-8,                    # variance floor
    ):
        """
        Args:
            params: parameters (or param-group dicts) to optimize.
            lr, betas, eps, weight_decay, amsgrad, maximize: Adam(W) hyper-parameters;
                weight decay is decoupled (applied directly to the parameters).
            kappa: DiSK filter gain. Must be > 0, and != 1 when the two-evaluation rule
                is selected (see ``gamma``).
            gamma: DiSK extrapolation factor. ``0`` or a value within 1e-3 of
                ``(1 - kappa) / kappa`` selects the single-evaluation variant and snaps
                gamma to ``(1 - kappa) / kappa``.
            omega: residual-filter gain in (0, 1].
            dp_noise_std: std of DP noise BEFORE averaging (C×σ); ``None`` disables the
                noise-variance correction.
            batch_size: logical batch size used by the DP engine for averaging.
            use_filter_aware_adambc: subtract the filtered DP-noise variance from the
                bias-corrected second moment.
            v_min: floor applied to the corrected second moment.

        Raises:
            ValueError: if a hyper-parameter is out of range.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        # omega == 0 would freeze the filter at zero (no updates) and makes the
        # noise-variance factor meaningless, so it is rejected.
        if not (0.0 < omega <= 1.0):
            raise ValueError(f"Invalid omega: {omega}")
        if kappa <= 0.0:
            raise ValueError(f"Invalid kappa: {kappa}")
        if dp_noise_std is not None and dp_noise_std < 0.0:
            raise ValueError(f"dp_noise_std must be >= 0, got {dp_noise_std}")
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be > 0, got {batch_size}")
        if v_min < 0.0:
            raise ValueError(f"v_min must be >= 0, got {v_min}")

        # ---- Two-point DiSK setup (same logic as KFOptimizer wrapper) ----
        gamma, self.compute_grad, self.scaling_factor = self._two_point_setup(kappa, gamma)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            kappa=kappa,
            gamma=gamma,
            omega=omega,
            dp_noise_std=dp_noise_std,
            batch_size=batch_size,
            use_filter_aware_adambc=use_filter_aware_adambc,
            v_min=v_min,
        )
        super().__init__(params, defaults)

        # Flag: set True by prestep() when two-point was actually applied this update
        self._used_two_point: bool = False

    @staticmethod
    def _two_point_setup(kappa: float, gamma: float) -> Tuple[float, bool, float]:
        """Resolve ``(gamma, compute_grad, scaling_factor)`` for the two-point rule.

        ``gamma == 0`` or ``gamma`` within 1e-3 of ``(1 - kappa) / kappa`` selects the
        single-evaluation variant (no gradient at x_t, gamma snapped to that value).
        Otherwise both evaluations are combined with
        ``scaling_factor = (gamma * kappa + kappa - 1) / (1 - kappa)``.

        Raises:
            ValueError: if the two-evaluation variant is requested with ``kappa == 1``.
        """
        gamma_single = (1 - kappa) / kappa
        if gamma == 0.0 or abs(gamma - gamma_single) < 1e-3:
            return gamma_single, False, 1.0
        if kappa == 1.0:
            raise ValueError(
                f"kappa=1 requires gamma=0 (single-evaluation rule), got gamma={gamma}"
            )
        # scaling_factor used to reweight gradients in pre-step
        return gamma, True, (gamma * kappa + kappa - 1.0) / (1.0 - kappa)

    def __setstate__(self, state):
        super().__setstate__(state)
        # torch's Optimizer.__getstate__ keeps only defaults/state/param_groups, so the
        # derived two-point attributes are lost by deepcopy/pickle. Rebuild them from the
        # (already normalized) defaults; the result depends only on self.defaults, so it is
        # also harmless when load_state_dict routes through here.
        _, self.compute_grad, self.scaling_factor = self._two_point_setup(
            self.defaults["kappa"], self.defaults["gamma"]
        )
        self.__dict__.setdefault("_used_two_point", False)

    @torch.no_grad()
    def set_dp_noise_std(self, dp_noise_std: Optional[float]) -> None:
        """Set per-coordinate DP noise std BEFORE averaging (C×σ) for every param group.

        Raises:
            ValueError: if ``dp_noise_std`` is negative.
        """
        if dp_noise_std is not None and dp_noise_std < 0.0:
            raise ValueError(f"dp_noise_std must be >= 0, got {dp_noise_std}")
        value = None if dp_noise_std is None else float(dp_noise_std)
        for group in self.param_groups:
            group["dp_noise_std"] = value

    @staticmethod
    def _grad_attr(p: torch.Tensor) -> Optional[str]:
        """Return the gradient attribute consumed for ``p``: 'private_grad' first, then 'grad'."""
        if getattr(p, "private_grad", None) is not None:
            return "private_grad"
        if p.grad is not None:
            return "grad"
        return None

    def _shift_params(self, alpha: float) -> None:
        """Add ``alpha * kf_d_t`` in place to every parameter that has a stored direction."""
        for group in self.param_groups:
            for p in group["params"]:
                direction = self.state.get(p, {}).get("kf_d_t")
                if direction is not None:
                    p.data.add_(direction, alpha=alpha)

    def _scale_grads(self) -> list:
        """Multiply every consumed gradient by ``scaling_factor``.

        Returns the ``(param, attribute_name)`` pairs that were scaled, so exactly those
        gradients are divided back afterwards.
        """
        scaled = []
        for group in self.param_groups:
            for p in group["params"]:
                attr = self._grad_attr(p)
                if attr is None:
                    continue  # no gradient (likely unused/frozen); nothing to scale
                getattr(p, attr).mul_(self.scaling_factor)
                scaled.append((p, attr))
        return scaled

    def prestep(self, closure=required):
        """
        Optional two-point DiSK-style rule.

        If at least one parameter has a stored direction d_{t-1} (``kf_d_t``):
        - compute_grad=True: evaluate closure() at x_t, multiply every existing gradient
          by ``scaling_factor``, evaluate closure() again at x_t + gamma * d_{t-1}, restore
          x_t and divide the scaled gradients by ``scaling_factor``. Assuming closure()
          accumulates into existing gradients (as ``loss.backward()`` does), ``step()``'s
          division by z = 1 + 1/scaling_factor yields
          (scaling_factor * g(x_t) + g(x_t + gamma d)) / (scaling_factor + 1).
        - compute_grad=False: evaluate closure() once, at x_t + gamma * d_{t-1}.
        Otherwise closure() is evaluated once at x_t and the two-point flag stays off.

        Parameters (and scaled gradients) are restored even if closure() raises.

        Args:
            closure: callable that recomputes the loss and back-propagates.

        Returns:
            The loss from the evaluation at x_t when compute_grad=True (or when no
            direction exists); otherwise the loss at x_t + gamma * d_{t-1}.

        Raises:
            TypeError: if no callable closure is given.
        """
        if closure is required or not callable(closure):
            raise TypeError("prestep() requires a callable closure")

        # Default: assume not used; set True only once the two-point procedure completes.
        self._used_two_point = False

        # Without any stored direction both evaluations would coincide, so do just one.
        has_direction = any(
            "kf_d_t" in self.state.get(p, {})
            for group in self.param_groups
            for p in group["params"]
        )
        if not has_direction:
            with torch.enable_grad():
                return closure()

        gamma = self.defaults["gamma"]
        loss_out = None

        # First evaluation at x_t (only if compute_grad=True)
        if self.compute_grad:
            with torch.enable_grad():
                loss_out = closure()

        with torch.no_grad():
            self._shift_params(gamma)  # x_t -> x_t + gamma d_{t-1}
            scaled = self._scale_grads() if self.compute_grad else []
        try:
            # Second evaluation at x_t + gamma d_{t-1}
            with torch.enable_grad():
                shifted_loss = closure()
        finally:
            # Always undo the perturbation and the scaling, even if closure() raised.
            with torch.no_grad():
                self._shift_params(-gamma)
                for p, attr in scaled:
                    grad = getattr(p, attr, None)
                    if grad is not None:
                        grad.div_(self.scaling_factor)

        if not self.compute_grad:
            # single-backward variant uses only the shifted evaluation
            loss_out = shifted_loss

        self._used_two_point = True
        return loss_out

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None):
        """
        Single optimization step.
        DP-safe: does NOT call closure(); ``closure`` is accepted only for API compatibility.

        Returns:
            Always ``None``.
        """
        loss = None

        # Gate z-normalization by whether prestep() completed the two-evaluation rule.
        z = None
        if getattr(self, "_used_two_point", False) and self.compute_grad:
            z = 1.0 + 1.0 / self.scaling_factor  # same z used in KFOptimizer wrapper

        for group in self.param_groups:
            self._step_group(group, z)

        # reset flag: by default next step is 1-point unless prestep sets it again
        self._used_two_point = False
        return loss

    @staticmethod
    def _filtered_noise_var(group: dict, z: Optional[float]) -> Optional[float]:
        """Per-coordinate DP-noise variance after the residual filter, or None if unused.

        Multiplies sigma^2 by A(omega) = (2 - omega) / (4 - 3 omega), the stationary
        variance factor of the filter; divides by batch_size**2 when a batch size is set
        (the std is given before averaging) and by z**2 when two-point normalization applies.
        """
        dp_noise_std = group.get("dp_noise_std", None)
        use_bc = bool(group.get("use_filter_aware_adambc", True))
        if not (use_bc and dp_noise_std is not None and dp_noise_std > 0.0):
            return None

        omega = group["omega"]
        var_factor = (2.0 - omega) / (4.0 - 3.0 * omega)  # A(omega)
        noise_var = (float(dp_noise_std) ** 2) * var_factor

        batch_size = group.get("batch_size", None)
        if batch_size is not None and batch_size > 0:
            noise_var = noise_var / (batch_size ** 2)
        # else: fallback, only correct if dp_noise_std already accounts for averaging

        if z is not None:
            # Gradients were divided by z, so their noise variance scales by 1/z^2.
            # NOTE(review): this treats the combined two-point gradient as carrying a single
            # noise draw — confirm against how the DP engine injects noise across closures.
            noise_var = noise_var / (z ** 2)
        return noise_var

    def _step_group(self, group: dict, z: Optional[float]) -> None:
        """Apply the filtered AdamW update to one param group.

        Args:
            group: a param group dict from ``self.param_groups``.
            z: two-point normalizer (gradients are divided by it), or ``None``.

        Raises:
            RuntimeError: on sparse gradients.
        """
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        weight_decay = group["weight_decay"]
        amsgrad = group["amsgrad"]
        maximize = group["maximize"]
        omega = group["omega"]
        v_min = float(group.get("v_min", 0.0))

        noise_var_filtered = self._filtered_noise_var(group, z)

        # ---- collect grads (private_grad takes precedence over grad) ----
        params_with_grad = []
        grads = []
        idle = []
        has_private_grad = False
        for p in group["params"]:
            attr = self._grad_attr(p)
            if attr is None:
                idle.append(p)
                continue
            grad = getattr(p, attr)
            if grad.is_sparse:
                raise RuntimeError("Fiber does not support sparse gradients")
            g = grad.detach().clone()
            if z is not None:
                g.div_(z)
            has_private_grad = has_private_grad or attr == "private_grad"
            params_with_grad.append(p)
            grads.append(g)

        # Parameters without a gradient do not move this step, so their displacement
        # x_{t+1} - x_t is zero; a stale direction must not feed the next prestep().
        for p in idle:
            direction = self.state.get(p, {}).get("kf_d_t")
            if direction is not None:
                direction.zero_()

        if not params_with_grad:
            return

        # ---- residual denoising ----
        denoised_grads = []
        states = []
        for p, g in zip(params_with_grad, grads):
            state = self.state[p]
            if len(state) == 0:
                state["step"] = 0
                state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if amsgrad:
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["temp_grad"] = torch.zeros_like(g)     # g_tilde_{t-1}
                state["residual_ema"] = torch.zeros_like(g)  # r_{t-1}

            # Store -x_t now; x_{t+1} is added after the update so kf_d_t = x_{t+1} - x_t.
            state["kf_d_t"] = p.detach().neg()

            g_tilde_prev = state["temp_grad"]
            r_prev = state["residual_ema"]
            r_t = (1.0 - omega) * r_prev + omega * (g - g_tilde_prev)  # innovation EMA
            g_tilde_t = g_tilde_prev + r_t
            state["residual_ema"] = r_t
            state["temp_grad"] = g_tilde_t

            # Independent copy: the normalization below divides in place.
            denoised_grads.append(g_tilde_t.neg() if maximize else g_tilde_t.clone())
            states.append(state)

        # Global (per-group) normalization for the non-DP `.grad` path. Skipped when the
        # DP-noise correction is active: noise_var_filtered is in raw-gradient units and
        # would no longer match normalized gradients.
        if not has_private_grad and noise_var_filtered is None:
            total_sq = sum(d.norm().pow(2).item() for d in denoised_grads)
            if total_sq > 0:
                total_norm = math.sqrt(total_sq)
                for d in denoised_grads:
                    d.div_(total_norm)

        # ---- AdamW update using denoised grads ----
        for p, grad_use, state in zip(params_with_grad, denoised_grads, states):
            state["step"] += 1
            t = state["step"]

            # decoupled weight decay
            if weight_decay != 0.0:
                p.data.add_(p.data, alpha=-lr * weight_decay)

            exp_avg = state["exp_avg"]
            exp_avg_sq = state["exp_avg_sq"]
            exp_avg.mul_(beta1).add_(grad_use, alpha=1.0 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_use, grad_use, value=1.0 - beta2)

            if amsgrad:
                max_exp_avg_sq = state["max_exp_avg_sq"]
                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                v = max_exp_avg_sq
            else:
                v = exp_avg_sq

            v_hat = v / (1.0 - beta2 ** t)
            if noise_var_filtered is not None:
                v_hat = (v_hat - noise_var_filtered).clamp_(min=v_min)  # variance floor

            denom = v_hat.sqrt_().add_(eps)
            p.data.addcdiv_(exp_avg, denom, value=-lr / (1.0 - beta1 ** t))

            # finalize direction: kf_d_t = -x_t + x_{t+1}
            state["kf_d_t"].add_(p.data)
