from __future__ import annotations

from typing import Optional, Callable
from dataclasses import dataclass
import numpy as np

from toy_example.problems import ToyExampleSol
from toy_example.common import GradEstimatorX, GradEstimatorY, ConstraintFn, Projection, Array, BasicAlgoConfig
from toy_example.algorithms.log_utils import BaseLogger


# ---- Config and state -----------------
@dataclass(slots=True)
class AlgoConfig(BasicAlgoConfig):
    """Hyper-parameters for the SPACO loop.

    ``alpha``, ``beta``, ``eta``, ``prox`` and ``rho`` each accept either a constant or a
    schedule ``k -> value``. Read them through the matching ``*_at(k)`` accessor, which
    calls the schedule with ``k``, or converts the constant with ``float``.
    """

    # Loop control.
    max_iters: int = 1000
    log_every: int = 100
    seed: int | None = 42

    # Step sizes / coefficients (constant or schedule of the iteration index).
    alpha: float | Callable[[int], float] = 1e-3  # x step
    beta: float | Callable[[int], float] = 1e-2  # y step
    eta: float | Callable[[int], float] = 0.8  # momentum blend (STORM)
    prox: float | Callable[[int], float] = 0.1  # prox parameter for y
    rho: float | Callable[[int], float] = 1.0  # penalty parameter (schedule or constant)

    is_storm: bool = True  # whether to use STORM-style momentum

    @staticmethod
    def _value_at(setting: float | Callable[[int], float], k: int) -> float:
        """Resolve a constant-or-schedule setting at iteration ``k``."""
        if not callable(setting):
            return float(setting)
        return setting(k)

    def alpha_at(self, k: int) -> float:
        """x step size at iteration ``k``."""
        return self._value_at(self.alpha, k)

    def beta_at(self, k: int) -> float:
        """y step size at iteration ``k``."""
        return self._value_at(self.beta, k)

    def eta_at(self, k: int) -> float:
        """STORM momentum blend at iteration ``k``."""
        return self._value_at(self.eta, k)

    def prox_at(self, k: int) -> float:
        """Proximal coefficient for y at iteration ``k``."""
        return self._value_at(self.prox, k)

    def rho_at(self, k: int) -> float:
        """Penalty parameter at iteration ``k``."""
        return self._value_at(self.rho, k)


@dataclass(slots=True)
class AlgoState:
    x: Array
    y: Array
    iter: int = 1
    x_prev: Array | None = None  # previous x, for STORM momentum
    y_prev: Array | None = None  # previous y, for STORM momentum
    gx_prev: Array | None = None  # previous gradient of x, for STORM momentum
    gy_prev: Array | None = None  # previous gradient of y, for STORM momentum


# ---- Algorithm implementation ---------
class SPACO:
    """
    Minimal SPACO loop; plug in your oracles for gradients, constraints, and projections.
    Ambiguous formulae from the PDF (e.g., exact penalty/multiplier tweaks) are TODOs.

    Each iteration takes a projected ascent step in ``y`` at ``(x_k, y_k)``, followed by
    a projected descent step in ``x`` at ``(x_k, y_{k+1})``. Both use penalized gradients:

        g_x = grad_x(x, y) - rho * sum_i max(c_i(x, y), 0) * grad_cx(x, y)[i]
        g_y = grad_y(x, y) - rho * sum_i max(c_i(x, y), 0) * grad_cy(x, y)[i] - prox * y

    When ``cfg.is_storm`` is set, each direction is blended STORM-style:
    ``d_k = g(z_k) + (1 - eta) * (d_{k-1} - g(z_{k-1}))``. Both gradient evaluations in
    that formula use the same noise samples.

    NOTE(review): ``cfg.seed`` is not read anywhere in this class. Any seeding must
    happen in ``noise_sampler`` or in the caller.
    """

    def __init__(
        self,
        *,
        cfg: AlgoConfig,
        proj_x: Projection,
        proj_y: Projection,
        constraint: ConstraintFn,
        grad_y: GradEstimatorY,
        grad_x: GradEstimatorX,
        grad_cx: Callable[[Array, Array], Array],
        grad_cy: Callable[[Array, Array], Array],
        noise_sampler: Callable[[Optional[int]], Array],
        objective_fn: Callable[[Array, Array], float],
        exact_sol: Optional[ToyExampleSol] = None,
    ) -> None:
        """Store the oracles and build the logger.

        Args:
            cfg: Hyper-parameters and schedules.
            proj_x, proj_y: Projections applied after each x / y step.
            constraint: ``c(x, y)``; its positive part is penalized.
            grad_y, grad_x: Stochastic gradient oracles, called as ``f(x, y, samples=...)``.
            grad_cx, grad_cy: Constraint Jacobians. Row ``i`` is assumed to be the gradient
                of ``c_i`` (the layout the penalty contraction expects).
            noise_sampler: Called with ``None`` once per step; its result is passed as
                ``samples`` to ``grad_x``.
            objective_fn: Forwarded to the logger.
            exact_sol: Optional reference solution, forwarded to the logger.
        """
        self.cfg = cfg
        self.proj_x = proj_x
        self.proj_y = proj_y
        self.constraint = constraint
        self.grad_y = grad_y
        self.grad_x = grad_x
        self.grad_cx = grad_cx
        self.grad_cy = grad_cy
        self.noise_sampler = noise_sampler
        self.objective_fn = objective_fn
        self.exact_sol = exact_sol
        self.logger = BaseLogger(objective_fn, constraint, exact_sol)

    def run(self, init: ToyExampleSol):
        """Run ``cfg.max_iters`` iterations from ``init`` and return the final ``AlgoState``.

        ``init.x``/``init.y`` are copied before iterating. The initial state is logged as
        step 0; after that, a state is logged every ``cfg.log_every`` iterations when
        ``log_every`` is truthy. The final state is logged only if ``max_iters`` is a
        multiple of ``log_every``.
        """
        state = AlgoState(init.x.copy(), init.y.copy())
        self._log(0, state)  # log initial state

        for k in range(1, self.cfg.max_iters + 1):  # iteration index starts from 1
            state.iter = k
            state = self._step(state)

            if self.cfg.log_every and k % self.cfg.log_every == 0:
                self._log(k, state)
        return state

    def _step(self, state: AlgoState) -> AlgoState:
        """Perform one SPACO iteration and return a freshly built ``AlgoState``.

        The y step is taken first, at ``(x_k, y_k)``. The x step is then taken at
        ``(x_k, y_{k+1})``. The returned state records the iterates and directions used
        here, as input for the next STORM correction.
        """
        k = state.iter
        alpha, beta, eta, prox, rho = (
            self.cfg.alpha_at(k),
            self.cfg.beta_at(k),
            self.cfg.eta_at(k),
            self.cfg.prox_at(k),
            self.cfg.rho_at(k),
        )
        x_prev, y_prev = state.x_prev, state.y_prev

        y_samples = None  # warning: current self.grad_y() does not need y_samples, just set it to None to speed up
        x_samples = self.noise_sampler(None)

        # y-update (projected ascent)
        org_gy = self._compute_spaco_grad_y(state.x, state.y, rho=rho, prox=prox, samples=y_samples)

        if state.gy_prev is None or not self.cfg.is_storm:
            gy = org_gy
        else:
            assert x_prev is not None and y_prev is not None, "x_prev, y_prev must be set if gy_prev is set"
            # Re-evaluate at the point where gy_prev was computed, (x_{k-1}, y_{k-1}), with this step's samples.
            gy_prev_corr = self._compute_spaco_grad_y(x_prev, y_prev, rho=rho, prox=prox, samples=y_samples)
            gy = org_gy + (1.0 - eta) * (state.gy_prev - gy_prev_corr)

        y_new = self.proj_y(state.y + beta * gy)

        # x-update (projected descent, using the freshly updated y)
        org_gx = self._compute_spaco_grad_x(state.x, y_new, rho=rho, samples=x_samples)

        if state.gx_prev is None or not self.cfg.is_storm:
            gx = org_gx
        else:
            assert x_prev is not None, "x_prev must be set if gx_prev is set"
            # gx_prev was evaluated at (x_{k-1}, y_k), so state.y (= y_k) is the matching point here.
            gx_prev_corr = self._compute_spaco_grad_x(x_prev, state.y, rho=rho, samples=x_samples)
            gx = org_gx + (1.0 - eta) * (state.gx_prev - gx_prev_corr)

        x_new = self.proj_x(state.x - alpha * gx)

        algo_state = AlgoState(
            x=x_new,
            y=y_new,
            x_prev=state.x,
            y_prev=state.y,
            gx_prev=gx,
            gy_prev=gy,
            iter=state.iter + 1,
        )

        return algo_state

    @staticmethod
    def _penalty_grad(c_val: Array, c_grad: Array) -> Array:
        """Return ``sum_i max(c_i, 0) * c_grad[i]``.

        If row ``i`` of ``c_grad`` is the gradient of ``c_i``, this is the gradient of
        ``0.5 * ||max(c, 0)||^2``. A single (scalar) constraint value paired with a 1-D
        gradient is treated as a one-row Jacobian.
        """
        c_pos = np.clip(np.atleast_1d(c_val), 0.0, None)
        c_grad = np.asarray(c_grad)
        if c_grad.ndim == 1 and c_pos.size == 1:
            c_grad = c_grad[np.newaxis, :]
        return np.einsum("i,ij->j", c_pos, c_grad)

    def _compute_spaco_grad_x(self, x: Array, y: Array, *, rho: float, samples: Optional[Array] = None) -> Array:
        """Penalized x-gradient ``grad_x(x, y) - rho * penalty_grad_x(x, y)``.

        Returns a new array. The oracle's output is never modified in place, because
        oracles may return one of their inputs or a cached buffer (e.g. the x-gradient of
        ``x^T y`` is ``y``).
        """
        grad = self.grad_x(x, y, samples=samples)
        penalty = self._penalty_grad(self.constraint(x, y), self.grad_cx(x, y))
        return grad - rho * penalty

    def _compute_spaco_grad_y(self, x: Array, y: Array, *, rho: float, prox: float, samples: Optional[Array] = None) -> Array:
        """Penalized, proximally regularized y-gradient.

        Computes ``grad_y(x, y) - rho * penalty_grad_y(x, y) - prox * y`` and returns a
        new array. As in the x case, the oracle's output is not modified in place.
        """
        # objective gradient
        grad_y = self.grad_y(x, y, samples=samples)

        # constraint (penalty) gradient
        penalty = self._penalty_grad(self.constraint(x, y), self.grad_cy(x, y))

        # proximal term pulls y toward the origin
        return grad_y - rho * penalty - prox * y

    def _log(self, step: int, state: AlgoState) -> None:
        """Send ``state`` and the hyper-parameters of iteration ``step`` to the logger.

        ``run`` logs after ``_step``, which has already advanced ``state.iter`` to the
        next index. The parameters are therefore evaluated at ``step``, the iteration
        that produced this iterate. The initial state (step 0) reports the values
        iteration 1 will use. Once ``state.iter`` reaches
        ``max_iters - 5 * log_every``, ``x`` and ``y`` are also printed.
        """
        k = max(step, 1)
        rho = self.cfg.rho_at(k)
        extra_metrics = {
            "params/rho": rho,
            "params/prox": self.cfg.prox_at(k),
            "params/alpha": self.cfg.alpha_at(k),
            "params/beta": self.cfg.beta_at(k),
            "params/eta": self.cfg.eta_at(k),
        }

        extra_log_items = [
            f"rho={rho:.2f}",
        ]

        self.logger.log(step, state, extra_metrics, extra_log_items)

        if state.iter >= self.cfg.max_iters - 5 * self.cfg.log_every:
            with np.printoptions(precision=2, suppress=False):
                print("x:", state.x)
                print("y:", state.y)
