"""Multiplier Gradient Descent (MGD)

Tsaknakis, Ioannis C. et al.
“Minimax Problems with Coupled Linear Constraints: Computational Complexity and Duality.”
SIAM J. Optim. 33 (2023): 2675-2702.
"""

from __future__ import annotations

from typing import Optional, Callable
from dataclasses import dataclass
import numpy as np

from toy_example.problems import ToyExampleSol
from toy_example.common import GradEstimatorX, GradEstimatorY, ConstraintFn, Projection, Array, BasicAlgoConfig
from toy_example.algorithms.log_utils import BaseLogger


# ---- Config and state -----------------
@dataclass(slots=True)
class AlgoConfig(BasicAlgoConfig):
    """Hyperparameters for an MGD run.

    ``alpha`` and ``beta`` may each be a plain number or a callable mapping
    the iteration index to a step size; ``alpha_at`` / ``beta_at`` resolve
    either form to the value for a given iteration.
    """

    max_iters: int = 1000  # number of outer iterations performed by the run loop
    log_every: int = 100  # log every this many iterations; a falsy value disables periodic logging
    seed: int | None = 42  # not read by the code visible in this file

    alpha: float | Callable[[int], float] = 1e-3  # x step size (constant or schedule)
    beta: float | Callable[[int], float] = 1e-2  # y step size (constant or schedule)
    dual_stepsize: float = 1.0  # step size of the multiplier update
    inner_steps: int = 1  # primal (y, x) update rounds per multiplier update

    def alpha_at(self, k: int) -> float:
        """Return the x step size to use at iteration ``k``."""
        schedule = self.alpha
        if callable(schedule):
            return schedule(k)
        return float(schedule)

    def beta_at(self, k: int) -> float:
        """Return the y step size to use at iteration ``k``."""
        schedule = self.beta
        if callable(schedule):
            return schedule(k)
        return float(schedule)


@dataclass(slots=True)
class AlgoState:
    """Iterate carried between MGD steps: both primal variables, the multiplier and a counter."""
    x: Array  # primal variable updated by projected gradient descent
    y: Array  # primal variable updated by projected gradient ascent
    lam: Array  # multiplier; starts as zeros shaped like the constraint value and is clipped to be >= 0
    iter: int = 1  # iteration index used to look up the step sizes for the next step


# ---- Algorithm implementation ---------
class MGD:
    """Multiplier Gradient Descent (MGD).

    Every outer step performs ``cfg.inner_steps`` rounds of projected gradient
    ascent in ``y`` followed by projected gradient descent in ``x`` with the
    multiplier ``lam`` held fixed, and then takes a single clipped step on
    ``lam`` driven by the constraint value at the new primal point.
    """

    def __init__(
        self,
        *,
        cfg: AlgoConfig,
        proj_x: Projection,
        proj_y: Projection,
        constraint: ConstraintFn,
        grad_y: GradEstimatorY,
        grad_x: GradEstimatorX,
        grad_cx: Callable[[Array, Array], Array],
        grad_cy: Callable[[Array, Array], Array],
        noise_sampler: Callable[[Optional[int]], Array],
        objective_fn: Callable[[Array, Array], float],
        exact_sol: Optional[ToyExampleSol] = None,
    ) -> None:
        """Store the problem oracles and build the logger.

        Args:
            cfg: Step sizes, iteration budget and logging cadence.
            proj_x: Applied to ``x`` after every descent step.
            proj_y: Applied to ``y`` after every ascent step.
            constraint: Called as ``constraint(x, y)``; its value drives the
                multiplier update and fixes the shape of ``lam``.
            grad_y: Called as ``grad_y(x, y, samples=...)``.
            grad_x: Called as ``grad_x(x, y, samples=...)``.
            grad_cx: Called as ``grad_cx(x, y)`` and combined with the
                multiplier as ``lam @ grad_cx(x, y)``.
            grad_cy: Called as ``grad_cy(x, y)`` and combined with the
                multiplier as ``lam @ grad_cy(x, y)``.
            noise_sampler: Called with ``None`` once per inner round; the
                result is forwarded to ``grad_x`` as ``samples``.
            objective_fn: Forwarded to the logger.
            exact_sol: Optional reference solution, forwarded to the logger.
        """
        self.cfg = cfg
        self.proj_x = proj_x
        self.proj_y = proj_y
        self.constraint = constraint
        self.grad_y = grad_y
        self.grad_x = grad_x
        self.grad_cx = grad_cx
        self.grad_cy = grad_cy
        self.noise_sampler = noise_sampler
        self.objective_fn = objective_fn
        self.exact_sol = exact_sol
        self.logger = BaseLogger(objective_fn, constraint, exact_sol)

    def run(self, init: ToyExampleSol):
        """Run ``cfg.max_iters`` outer steps starting from ``init``.

        ``init.x`` and ``init.y`` are copied, so the caller's arrays are not
        the ones stored in the state. The multiplier starts at zero with the
        shape of ``constraint(init.x, init.y)``.

        Returns:
            The ``AlgoState`` after the last step.
        """
        # Initialize state
        init_lam = np.zeros_like(self.constraint(init.x, init.y))
        state = AlgoState(init.x.copy(), init.y.copy(), init_lam)
        self._log(0, state)  # log initial state

        for k in range(1, self.cfg.max_iters + 1):
            # Pin the counter to the loop index so the step-size lookup in
            # _step always sees k, whatever _step stored on the way out.
            state.iter = k
            state = self._step(state)

            if self.cfg.log_every and k % self.cfg.log_every == 0:
                self._log(k, state)
        return state

    def _step(self, state: AlgoState) -> AlgoState:
        """Perform one outer step and return the resulting state.

        The returned state carries ``iter = state.iter + 1``.
        """
        # Get hyperparameters ...
        k = state.iter
        inner_steps = self.cfg.inner_steps
        x_stepsize = self.cfg.alpha_at(k)
        y_stepsize = self.cfg.beta_at(k)
        dual_stepsize = self.cfg.dual_stepsize

        # Primal Update (multiplier held fixed for all inner rounds)
        x_new, y_new, lam = state.x, state.y, state.lam

        for _ in range(inner_steps):
            # Gradient Ascent for y (no samples are passed to the y estimator)
            gy = self._compute_grad_y(x_new, y_new, lam, samples=None)
            y_new = self.proj_y(y_new + y_stepsize * gy)

            # Gradient Descent for x, evaluated at the freshly updated y
            x_samples = self.noise_sampler(None)
            gx = self._compute_grad_x(x_new, y_new, lam, samples=x_samples)
            x_new = self.proj_x(x_new - x_stepsize * gx)

        # Dual Update: step along the constraint value, then clip at zero
        con = self.constraint(x_new, y_new)
        lam_new = np.clip(lam + dual_stepsize * con, 0, None)

        # Next state
        algo_state = AlgoState(
            x=x_new, y=y_new, lam=lam_new, iter=state.iter + 1,
        )

        return algo_state

    def _compute_grad_x(self, x: Array, y: Array, lam: Array, *, samples: Optional[Array] = None) -> Array:
        """Return ``grad_x(x, y, samples=samples) - lam @ grad_cx(x, y)``.

        The result is a new array; the value returned by ``grad_x`` is never
        modified.
        """
        grad = self.grad_x(x, y, samples=samples)
        # Out-of-place on purpose: the estimator may return an array that
        # aliases x (or a cached buffer), and an in-place ``-=`` would then
        # overwrite the iterate. It also avoids the casting error that in-place
        # subtraction raises when the estimator returns an integer array.
        return grad - lam @ self.grad_cx(x, y)

    def _compute_grad_y(self, x: Array, y: Array, lam: Array, *, samples: Optional[Array] = None) -> Array:
        """Return ``grad_y(x, y, samples=samples) - lam @ grad_cy(x, y)``.

        The result is a new array; the value returned by ``grad_y`` is never
        modified.
        """
        grad = self.grad_y(x, y, samples=samples)
        # Out-of-place for the same aliasing/casting reasons as _compute_grad_x.
        return grad - lam @ self.grad_cy(x, y)

    def _log(self, step: int, state: AlgoState) -> None:
        """Send ``state`` to the logger together with the multiplier norm.

        Once ``state.iter`` reaches ``max_iters - 5 * log_every`` the raw
        ``x``, ``y`` and ``lam`` arrays are also printed.
        """
        lam_norm = np.linalg.norm(state.lam)

        extra_metrics = {
            "metrics/lam_norm": lam_norm,
        }

        extra_log_items = [
            f"||lam||={lam_norm:.2e}",
        ]

        self.logger.log(step, state, extra_metrics, extra_log_items)

        if state.iter >= self.cfg.max_iters - 5 * self.cfg.log_every:
            with np.printoptions(precision=2, suppress=False):
                print("x:", state.x)
                print("y:", state.y)
                print("lam:", state.lam)
