# optimization/gradient_correction.py
"""
Inner-loop gradient correction optimizer for EGP.

Provides:
  - eta_t schedule: eta0 * (t/T)^gamma
  - optional Langevin noise addition
  - gradient clipping and early stop when ||g||2 <= delta_norm

API:
  x_out, diagnostics = apply_corrections(x, energy_fn, negatives_embs, x_target, cfg, t_idx, ne=None, rng=None)
where energy_fn.energy_and_grad(...) returns (energy, grad) for the current x.
"""
import torch
import math
import numpy as np

def eta_schedule(eta0, gamma, step_idx, total_steps):
    """
    Return the step size for one inner step: eta0 * (step_idx / total_steps)^gamma.

    step_idx is expected to be 1-based (1..total_steps), so the last step
    uses exactly eta0. A total_steps below 1 is treated as 1 to avoid a
    division by zero.
    """
    # Clamp the horizon to at least one step before dividing.
    horizon = float(total_steps) if total_steps > 1 else 1.0
    progress = step_idx / horizon
    return eta0 * (progress ** gamma)

def apply_corrections(x, energy_fn, negatives_embs, x_target, cfg, t_idx, ne=None, rng=None):
    """
    Run inner-loop gradient-descent corrections on latent x.

    Each inner step evaluates the energy and its gradient at the current
    point, stops early if the global L2 norm of the gradient is at most
    cfg.delta_norm, and otherwise moves by -eta * grad (plus optional
    Langevin noise), where eta follows eta_schedule.

    Params:
      x: torch.Tensor (described by the original notes as [B, D]; the code
         does not depend on the shape). The input is not modified; a
         detached copy is updated and returned.
      energy_fn: object exposing
         energy_and_grad(x, x_target_latent=..., negatives_text_embs=...)
         -> (energy, grad). energy must be a single-element tensor because
         it is converted with .item() for the diagnostics.
      negatives_embs: forwarded to energy_fn as negatives_text_embs.
      x_target: forwarded to energy_fn as x_target_latent.
      cfg: config object providing eta0, gamma, ne, delta_norm,
         use_langevin and langevin_scale.
      t_idx: current outer timestep index. Currently unused by this function.
      ne: number of inner steps; when None, cfg.ne is used. An explicit 0
         (or a negative value) performs no steps.
      rng: optional torch.Generator used to draw the Langevin noise. Any
         other value (including None) falls back to the global torch RNG.

    Returns:
      x_out: corrected tensor (same shape as x, detached).
      diagnostics: dict with
        - "final_energy": energy from the last evaluation (taken before the
          last update, not re-evaluated at x_out), or None if no evaluation
          happened;
        - "grad_norm": L2 norm of the gradient from that same evaluation,
          or None if no evaluation happened;
        - "steps_taken": number of updates actually applied.
    """
    device = x.device
    # `ne or cfg.ne` would silently discard an explicit ne=0 override.
    num_steps = cfg.ne if ne is None else ne
    x_cur = x.detach().clone().requires_grad_(True)

    final_energy = None
    # Stays None when the loop never runs, so building the diagnostics
    # cannot hit an unbound local.
    grad_norm_value = None
    steps_taken = 0

    for inner_i in range(1, num_steps + 1):
        energy, grad = energy_fn.energy_and_grad(
            x_cur, x_target_latent=x_target, negatives_text_embs=negatives_embs
        )
        final_energy = energy.detach()
        grad_norm_value = float(grad.norm(p=2).item())

        # Early stop if the gradient is very small; steps_taken keeps the
        # count of updates applied so far (inner_i - 1).
        if grad_norm_value <= cfg.delta_norm:
            break

        # The step size is only needed once we know a step will be taken.
        eta = eta_schedule(cfg.eta0, cfg.gamma, inner_i, num_steps)

        # Gradient step (descent on energy).
        step = -eta * grad

        # Optional Langevin noise with standard deviation sqrt(langevin_scale * eta).
        if cfg.use_langevin:
            sigma = math.sqrt(cfg.langevin_scale * eta)
            if isinstance(rng, torch.Generator):
                # Sample on the generator's own device (a generator cannot
                # sample for a different device), then move the result.
                noise = torch.randn(
                    step.shape, generator=rng, device=rng.device, dtype=step.dtype
                ) * sigma
            else:
                noise = torch.randn_like(step) * sigma
            step = step + noise.to(device)

        # Detach so that no autograd graph is carried across inner steps.
        x_cur = (x_cur + step).detach().requires_grad_(True)
        steps_taken = inner_i

    diagnostics = {
        "final_energy": float(final_energy.item()) if final_energy is not None else None,
        "grad_norm": grad_norm_value,
        "steps_taken": int(steps_taken),
    }
    return x_cur.detach(), diagnostics
