from __future__ import annotations

from dataclasses import dataclass, field
import numpy as np
import torch

from .base import BaseProblem, ToyExampleSol, Array


@dataclass
class LinearToyExample(BaseProblem):
    """
    Toy example with linear constraints.

    With ``n = config.dimension``, the problem couples a variable ``x`` of
    length ``n`` with a variable ``y`` of length ``2n`` stored as the
    concatenation ``[y1, y2]``. The objective evaluated by
    ``_objective_torch`` is

        0.5 * x^T A x + mean_i(0.5 * sigma * x^T E_i x)
        - (0.5 * ||y1||^2 - x^T y1 + e^T y2 + 0.5 * ||y2 + 2x + e||^2)

    where ``sigma = config.noise_level``, ``E_i`` are the supplied noise
    matrices and ``e`` is the all-ones vector. The coupling constraint
    evaluated by ``_constraints_torch`` is the scaled sum
    ``(e^T x + e^T y1 + e^T y2) / sqrt(n)``.
    """
    # Symmetric positive definite matrix of the quadratic term in x.
    A: Array = field(init=False)
    # All-ones vector of length ``config.dimension``.
    e: Array = field(init=False)

    def _setup_problem(self) -> None:
        """Set the bounds, draw ``A`` and cache float64 torch copies of ``A`` and ``e``.

        NOTE(review): presumably invoked by ``BaseProblem`` during
        initialisation, after ``self.config`` and ``self.rng`` exist -- confirm.
        """
        self.x_bounds = (-10.0, 10.0)
        self.y_bounds = (-10.0, 10.0)
        self.A = self._make_spd_matrix(self.config.dimension)
        self.e = np.ones(self.config.dimension, dtype=float)
        # Torch mirrors of the NumPy data, used by the *_torch methods below.
        self._e_t = torch.tensor(self.e, dtype=torch.float64)
        self._A_t = torch.tensor(self.A, dtype=torch.float64)

    def _make_spd_matrix(self, n: int) -> Array:
        """Return a random ``n x n`` symmetric positive definite matrix.

        The result is ``M M^T / n + I`` with ``M`` drawn from ``self.rng.normal``:
        ``M M^T`` is positive semi-definite and adding the identity makes the
        matrix positive definite, so ``A + I`` is SPD (and invertible) too.
        """
        m = self.rng.normal(size=(n, n))
        return (m @ m.T) / n + np.eye(n)

    def generate_random_guess(self) -> ToyExampleSol:
        """Return the initial guess ``x0 = 1`` (length n), ``y0 = 1`` (length 2n).

        Despite the method name, the guess is deterministic: it does not draw
        from ``self.rng``. The all-ones point lies inside the
        ``(-10.0, 10.0)`` bounds set in ``_setup_problem``.
        """
        x0 = np.ones(self.config.dimension, dtype=float)
        y0 = np.ones(2 * self.config.dimension, dtype=float)
        return ToyExampleSol(x0, y0)

    def exact_solution(self) -> ToyExampleSol:
        """Return the closed-form solution given in the formulation.

        The solution is ``x* = -2 (A + I)^{-1} e``, ``y1* = x* + e`` and
        ``y2* = -2 x* - e``; ``y1*`` and ``y2*`` are returned concatenated.
        """
        n = self.config.dimension
        # Solve (A + I) x* = -2 e directly rather than forming the inverse:
        # cheaper and numerically more accurate than inv(...) @ e.
        x_star = np.linalg.solve(self.A + np.eye(n), -2.0 * self.e)
        y1_star = x_star + self.e
        y2_star = -2.0 * x_star - self.e
        return ToyExampleSol(x_star, np.concatenate([y1_star, y2_star]))

    def _split_y(self, y: Array) -> tuple[Array, Array]:
        """Split the unified ``y`` into ``(y1, y2)``: first ``n`` entries and the rest."""
        n = self.config.dimension
        return y[:n], y[n:]

    # ---------- Torch definitions ------------
    def _split_y_torch(self, y_t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Torch counterpart of ``_split_y``: return ``(y_t[:n], y_t[n:])``."""
        n = self.config.dimension
        return y_t[:n], y_t[n:]

    def _constraints_torch(self, x_t: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
        """(1/sqrt(n))*(e^T x + e^T y1 + e^T y2) with symmetric +/- outputs.

        Returns a length-2 tensor ``[resid, -resid]``.
        NOTE(review): presumably this encodes the equality ``resid == 0`` as a
        pair of one-sided constraints -- confirm against the consumer.
        """
        y1_t, y2_t = self._split_y_torch(y_t)
        total = (
            torch.dot(self._e_t, x_t)
            + torch.dot(self._e_t, y1_t)
            + torch.dot(self._e_t, y2_t)
        )
        sqrt_n = torch.sqrt(
            torch.tensor(float(self.config.dimension), dtype=torch.float64)
        )
        resid = total / sqrt_n
        return torch.stack([resid, -resid])

    def _objective_torch(self, x_t: torch.Tensor, y_t: torch.Tensor, noise_samples_t: torch.Tensor) -> torch.Tensor:
        """Stochastic objective averaged over provided noise draws.

        ``noise_samples_t`` is either a stack of noise matrices (one per
        leading index) or a single 2-D matrix, which is treated as a stack of
        one. Returns a scalar tensor.
        """
        if noise_samples_t.ndim == 2:  # allow single matrix
            noise_samples_t = noise_samples_t.unsqueeze(0)

        y1_t, y2_t = self._split_y_torch(y_t)
        # x^T (A + sigma E_i) x averaged over noise
        stoc_vals = 0.5 * torch.einsum("nij,j,i->n", self.config.noise_level * noise_samples_t, x_t, x_t)
        det_val = 0.5 * torch.dot(x_t, self._A_t @ x_t)
        quad_term = stoc_vals.mean() + det_val

        # Squared norm as a dot product: avoids the deprecated torch.norm and a
        # sqrt-then-square round trip whose sqrt is non-differentiable at a zero
        # residual, which is exactly where exact_solution() places y2.
        coupling_t = y2_t + 2.0 * x_t + self._e_t
        coupling_sq = torch.dot(coupling_t, coupling_t)

        det_terms = -(
            0.5 * torch.dot(y1_t, y1_t)
            - torch.dot(x_t, y1_t)
            + torch.dot(self._e_t, y2_t)
            + 0.5 * coupling_sq
        )
        return quad_term + det_terms

    def gradient_y(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Return the gradient of the objective with respect to ``y``.

        ``samples`` is not used: the noise term of the objective involves only
        ``x``, so a single all-zero noise matrix is passed on instead.
        """
        # y-grad is independent of noise in current examples; avoid advancing RNG
        # NOTE(review): presumably _autograd_objective_grads would draw noise
        # itself if given samples=None -- confirm in the base class.
        samples_np = np.zeros((1, self.config.dimension, self.config.dimension), dtype=float)
        _, grad_y = self._autograd_objective_grads(x, y, samples=samples_np)
        return grad_y


if __name__ == "__main__":
    # Demo: evaluate the sampled objective and the constraint residual at the
    # closed-form solution and at the initial guess, using one shared noise draw.
    from toy_example.problems.base import ToyExampleCfg

    def _report(header, point):
        """Print the objective and constraint residual of ``demo`` at ``point``."""
        value = demo.stochastic_objective(point.x, point.y, noise_draws)
        resid = demo.constraint_residual(point.x, point.y)
        rounded = [round(float(r), 2) for r in resid]
        print(header)
        print("Objective (stochastic average):", f"{value:.2f}")
        print("Constraint residual:", f"{rounded}")

    demo = LinearToyExample(
        ToyExampleCfg(dimension=10, noise_level=1.0, seed=42, num_noise_samples=16)
    )
    exact = demo.exact_solution()
    noise_draws = demo.sample_noise()

    _report("---- Exact solution ----", exact)
    # The guess is generated only after the first report, as in the original flow.
    _report("\n---- Random guess ----", demo.generate_random_guess())
