from __future__ import annotations

from dataclasses import dataclass, field
import numpy as np

from .base import BaseProblem, ToyExampleSol, Array


@dataclass
class NonlinearToyExample(BaseProblem):
    """
    Toy min-max problem with a nonlinear coupling constraint.

    Problem formulation:
        min_{x in X} max_{y in Y} E_w [ -(alpha + w)^T x + (1/(2n)) (e^T y)^2 ]
        s.t. ||y|| - ||x|| <= 0

    where:
        X = [-10, 10]^n
        Y = [-10, 10]^n
        alpha in R^n is a fixed vector, drawn once uniformly from X
        w ~ N(0, sigma^2 I) is Gaussian noise (sigma = config.noise_level)
        e is the vector of ones in R^n

    The exact solution is:
        x* = alpha
        y* = ||alpha||/sqrt(n) * e
    """
    # Fixed linear-cost vector; populated by ``_setup_problem``.
    alpha: Array = field(init=False)
    # All-ones vector of length ``config.dimension``; populated by ``_setup_problem``.
    e: Array = field(init=False)

    # Norms below this threshold are treated as zero when forming the
    # constraint gradients (the Euclidean norm is not differentiable at 0).
    _NORM_TOL = 1e-10

    def _setup_problem(self) -> None:
        """Set the box bounds and draw the problem data ``alpha`` and ``e``.

        NOTE(review): presumably invoked by the base class during
        initialisation -- confirm against ``BaseProblem``.
        """
        self.x_bounds = (-10.0, 10.0)
        self.y_bounds = (-10.0, 10.0)
        # alpha is sampled inside X so that x* = alpha respects the box.
        low, high = self.x_bounds
        self.alpha = self.rng.uniform(low, high, size=self.config.dimension)
        self.e = np.ones(self.config.dimension, dtype=float)

    def sample_noise_fresh(self, count: int | None = None) -> Array:
        """Draw fresh Gaussian noise vectors w ~ N(0, sigma^2 I).

        Args:
            count: Number of noise vectors to draw. When ``None``,
                ``config.num_noise_samples`` is used.

        Returns:
            A float64 array of shape ``(num, config.dimension)``.
        """
        num = count if count is not None else self.config.num_noise_samples
        # The draw is made in float32 and scaled in place by sigma; the result
        # is widened to float64 only on return.
        raw = self.rng.standard_normal(size=(num, self.config.dimension), dtype=np.float32)
        raw *= self.config.noise_level
        return np.asarray(raw, dtype=np.float64)

    def generate_random_guess(self) -> ToyExampleSol:
        """Return the initial guess used to start a solver.

        Despite the name, the point is deterministic: ``x0`` is the zero
        vector and ``y0`` is the all-ones vector, both of length
        ``config.dimension``. Note that this point has
        ``||y0|| - ||x0|| = sqrt(n) > 0``, i.e. it starts outside the
        constraint set.
        """
        x0 = np.zeros(self.config.dimension, dtype=float)
        y0 = np.ones(self.config.dimension, dtype=float)
        return ToyExampleSol(x0, y0)

    def exact_solution(self) -> ToyExampleSol:
        """Return the closed-form solution (x*, y*) given in the formulation."""
        # x* = alpha (copied so callers cannot mutate the problem data).
        x_star = self.alpha.copy()

        # y* = ||x*|| / sqrt(n) * e
        scale = np.linalg.norm(x_star) / np.sqrt(self.config.dimension)
        y_star = scale * self.e

        return ToyExampleSol(x_star, y_star)

    # ---------- Constraints ------------
    def constraint_residual(self, x: Array, y: Array) -> Array:
        """Evaluate the constraint function ``||y|| - ||x||`` (feasible when <= 0).

        Returns:
            A length-1 array holding the single constraint value.
        """
        val = np.linalg.norm(y) - np.linalg.norm(x)
        return np.array([val])

    def constraint_gradient_x(self, x: Array, y: Array) -> Array:
        """Gradient of the constraint w.r.t. ``x``: ``-x / ||x||``.

        Returns:
            A ``(1, len(x))`` array (one row per constraint). A zero row is
            returned when ``||x||`` is below ``_NORM_TOL``, where the norm is
            not differentiable.
        """
        x_norm = np.linalg.norm(x)
        if x_norm < self._NORM_TOL:
            return np.zeros_like(x)[None, :]
        return (-x / x_norm)[None, :]

    def constraint_gradient_y(self, x: Array, y: Array) -> Array:
        """Gradient of the constraint w.r.t. ``y``: ``y / ||y||``.

        Returns:
            A ``(1, len(y))`` array (one row per constraint). A zero row is
            returned when ``||y||`` is below ``_NORM_TOL``, where the norm is
            not differentiable.
        """
        y_norm = np.linalg.norm(y)
        if y_norm < self._NORM_TOL:
            return np.zeros_like(y)[None, :]
        return (y / y_norm)[None, :]

    # ----- Gradients of Objective ------

    @staticmethod
    def _mean_noise(samples: Array) -> Array:
        """Average noise draws over the sample axis.

        A single 1-D noise vector is treated as a batch of one sample, so it
        is returned unchanged instead of being collapsed to a scalar.
        """
        return np.atleast_2d(samples).mean(axis=0)

    def gradient_x(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Compute the stochastic gradient of the objective w.r.t. ``x``.

        The objective is linear in ``x``, so the gradient is
        ``-(alpha + mean(w))`` and does not depend on ``x`` or ``y``.

        Args:
            x: Current x iterate (unused).
            y: Current y iterate (unused).
            samples: Noise draws of shape ``(num_samples, n)``, or a single
                draw of shape ``(n,)``. When ``None``, ``self.sample_noise()``
                supplies them.

        Returns:
            The gradient, one entry per component of ``alpha``.
        """
        draws = self.sample_noise() if samples is None else samples
        return -(self.alpha + self._mean_noise(draws))

    def gradient_y(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Compute the gradient of the objective w.r.t. ``y``.

        The y-part of the objective, ``(1/(2n)) (e^T y)^2``, contains no noise
        term, so ``samples`` (and ``x``) are ignored.

        Returns:
            ``(1/n) (e^T y) e``.
        """
        n = self.config.dimension
        dot_val = np.dot(self.e, y)
        grad_y = (1.0 / n) * dot_val * self.e

        return grad_y

    # ------------ Objective ------------

    def stochastic_objective(
        self,
        x: Array,
        y: Array,
        noise_samples: Array,
    ) -> float:
        """Compute the sample-average objective.

        Obj = -(alpha + w)^T x + (1/(2n)) (e^T y)^2

        Because the objective is linear in ``w``, averaging the noise first is
        equivalent to averaging the per-sample objective values.

        Args:
            x: Point in X.
            y: Point in Y.
            noise_samples: Noise draws of shape ``(num_samples, n)``, or a
                single draw of shape ``(n,)``.

        Returns:
            The averaged objective value as a Python float.
        """
        n = self.config.dimension

        avg_w = self._mean_noise(noise_samples)
        term1 = -np.dot(self.alpha + avg_w, x)

        term2 = (1.0 / (2.0 * n)) * (np.dot(self.e, y) ** 2)

        return float(term1 + term2)


if __name__ == "__main__":
    # Smoke test: build a small instance, then report the objective and the
    # constraint violation at the closed-form solution and at the initial guess.
    from toy_example.problems.base import ToyExampleCfg

    problem = NonlinearToyExample(
        ToyExampleCfg(dimension=10, noise_level=1.0, seed=42, num_noise_samples=16)
    )
    exact = problem.exact_solution()
    # One noise batch is shared by both evaluations so the numbers are comparable.
    noise_batch = problem.sample_noise()

    exact_value = problem.stochastic_objective(exact.x, exact.y, noise_batch)
    exact_violation = np.linalg.norm(problem.constraint_residual(exact.x, exact.y))

    print("---- Exact solution ----")
    print(f"alpha sample: {problem.alpha[:5]}...")
    print("Objective (stochastic average):", f"{exact_value:.2f}")
    print("Constraint error:", f"{exact_violation:.2f}")

    start = problem.generate_random_guess()
    start_value = problem.stochastic_objective(start.x, start.y, noise_batch)
    start_violation = np.linalg.norm(problem.constraint_residual(start.x, start.y))

    print("\n---- Random guess ----")
    print("Objective (stochastic average):", f"{start_value:.2f}")
    print("Constraint error:", f"{start_violation:.2f}")
