from __future__ import annotations

from dataclasses import dataclass, field
import numpy as np
import torch

from .base import BaseProblem, ToyExampleSol, Array


@dataclass
class NonlinearToyExample(BaseProblem):
    """
    Nonlinear Constraint Toy Example.

    A two-dimensional problem (``_setup_problem`` forces the configured
    dimension to 2) with box constraints x, y in [-10, 10]^2, objective

        f(x, y) = 5 ||x - e2||^2 + (x + x[::-1])^T y,   e2 = (1, 0),

    and bilinear inequality constraints

        c(x, y) = [x^T y - 1;  x[::-1]^T y - 1] <= 0.

    The objective does not use the noise samples it is given.
    """

    # Vector of ones of length ``config.dimension``; populated in ``_setup_problem``.
    e: Array = field(init=False)

    def _setup_problem(self) -> None:
        """Set the box bounds, force a 2-D problem and build constant vectors."""
        self.x_bounds = (-10.0, 10.0)
        self.y_bounds = (-10.0, 10.0)

        # ``e2`` and the closed-form solution are hard-coded for 2-D, so the
        # configured dimension is overridden here.
        # NOTE(review): this mutates the config object passed in by the caller.
        self.config.dimension = 2

        self.e = np.ones(self.config.dimension, dtype=float)
        self._e_t = torch.tensor(self.e, dtype=torch.float64)

        # First canonical basis vector; target of the quadratic objective term.
        self.e2 = np.array([1.0, 0.0], dtype=float)
        self._e2_t = torch.tensor(self.e2, dtype=torch.float64)

    def sample_noise_fresh(self, count: int | None = None) -> Array:
        """Draw fresh Gaussian noise vectors w ~ N(0, sigma^2 I).

        Args:
            count: Number of noise vectors to draw; ``None`` means
                ``config.num_noise_samples``.

        Returns:
            A float64 array of shape ``(count, config.dimension)`` whose
            entries are standard normals scaled by ``config.noise_level``.
        """
        num = count if count is not None else self.config.num_noise_samples
        # Sample directly in float64. Drawing in float32 and upcasting
        # afterwards left the samples with only single-precision resolution.
        raw = self.rng.standard_normal(
            size=(num, self.config.dimension), dtype=np.float64
        )
        raw *= self.config.noise_level  # N(0, 1) -> N(0, sigma^2)
        return raw

    def generate_random_guess(self) -> ToyExampleSol:
        """Generate a random initial guess uniformly within the box constraints."""
        x0 = self.rng.uniform(self.x_bounds[0], self.x_bounds[1], size=self.config.dimension)
        y0 = self.rng.uniform(self.y_bounds[0], self.y_bounds[1], size=self.config.dimension)
        return ToyExampleSol(x0, y0)

    def exact_solution(self) -> ToyExampleSol:
        """Return the closed-form solution x* = (1, 0), y* = (1, 1) given in the formulation."""
        x_star = self.e2.copy()
        y_star = np.array([1.0, 1.0], dtype=float)  # 2-D case only
        return ToyExampleSol(x_star, y_star)

    # ---------- Torch definitions ------------
    def _constraints_torch(self, x_t: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
        """c(x, y) = [x^T y - 1.0; x[::-1]^T y - 1.0] <= 0, stacked into one tensor."""
        term1 = torch.dot(x_t, y_t) - 1.0
        # ``flip`` along dim 0 reverses x, giving the x[::-1]^T y coupling.
        term2 = torch.dot(x_t.flip(dims=[0]), y_t) - 1.0
        return torch.stack([term1, term2])

    def _objective_torch(self, x_t: torch.Tensor, y_t: torch.Tensor, noise_samples_t: torch.Tensor) -> torch.Tensor:
        """Objective: 5 ||x - e2||^2 + (x + x[::-1])^T y.

        ``noise_samples_t`` is accepted for interface compatibility but unused:
        this objective is deterministic.
        """
        term1 = 5.0 * torch.sum((x_t - self._e2_t) ** 2)
        term2 = torch.dot(x_t + x_t.flip(dims=[0]), y_t)
        return term1 + term2

    def gradient_y(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Return the gradient of the objective with respect to y.

        ``samples`` is ignored. The y-gradient does not depend on noise, so a
        single all-zero dummy sample is passed instead of drawing fresh noise.
        This avoids advancing the RNG.
        """
        samples_np = np.zeros((1, self.config.dimension), dtype=float)
        _, grad_y = self._autograd_objective_grads(x, y, samples=samples_np)
        return grad_y

    # Override projections to enforce the boxes X = Y = [-10, 10]^n.
    def project_x(self, vector: Array) -> Array:
        """Clip ``vector`` elementwise onto ``x_bounds``."""
        low, high = self.x_bounds
        return np.clip(vector, low, high)

    def project_y(self, vector: Array) -> Array:
        """Clip ``vector`` elementwise onto ``y_bounds``."""
        low, high = self.y_bounds
        return np.clip(vector, low, high)


if __name__ == "__main__":
    from toy_example.problems.base import ToyExampleCfg

    def _report(title, candidate, noise_batch):
        """Evaluate *candidate* on *noise_batch* and print objective and violation norm."""
        value = problem.stochastic_objective(candidate.x, candidate.y, noise_batch)
        violation = problem.constraint_residual(candidate.x, candidate.y)
        print(title)
        print("Objective (stochastic average):", f"{value:.2f}")
        print("Constraint error:", f"{np.linalg.norm(violation):.2f}")

    # NonlinearToyExample._setup_problem forces config.dimension to 2, so the
    # value 10 below is presumably overridden during construction.
    cfg = ToyExampleCfg(dimension=10, noise_level=1.0, seed=42, num_noise_samples=16)
    problem = NonlinearToyExample(cfg)

    # Keep the call order (solution, noise, then random guess) so the RNG
    # stream is consumed exactly as before.
    exact = problem.exact_solution()
    noises = problem.sample_noise()
    _report("---- Exact solution ----", exact, noises)

    _report("\n---- Random guess ----", problem.generate_random_guess(), noises)
