"""
Base class for the toy minimax instance.
"""

from __future__ import annotations
from dataclasses import dataclass, field
from abc import ABC, abstractmethod

import numpy as np
import torch
Array = np.ndarray


@dataclass(slots=True)
class ToyExampleCfg:
    """Configuration of the toy minimax instance."""
    dimension: int = 10
    noise_level: float = 1.0
    seed: int | None = 42
    num_noise_samples: int = 16 # plays the role like to default batchsize
    noise_pool_refresh_threshold: int = 0 # interval of noise pool refreshing, 0 means no cache for noise

@dataclass(slots=True)
class ToyExampleSol:
    """Container for the primal-dual variables."""
    x: Array
    y: Array


@dataclass
class BaseProblem(ABC):
    """Abstract base class for toy problems.

    Subclasses supply the objective and the constraints as torch functions
    (``_objective_torch`` / ``_constraints_torch``); this class wraps them in
    numpy-facing APIs and obtains all derivatives through torch autograd on
    the CPU in float64.

    Attributes:
        config: Problem configuration. The base class reads ``seed``,
            ``dimension``, ``num_noise_samples`` and
            ``noise_pool_refresh_threshold`` from it.
        rng: NumPy generator seeded from ``config.seed``; created in
            ``__post_init__`` rather than passed to the constructor.

    Subclasses may additionally define ``x_bounds`` and/or ``y_bounds`` as
    ``(low, high)`` pairs to enable box projection in ``project_x`` /
    ``project_y``.
    """

    config: ToyExampleCfg
    rng: np.random.Generator = field(init=False)

    def __post_init__(self) -> None:
        """Seed the RNG, reset the noise-pool state and run the subclass hook."""
        self.rng = np.random.default_rng(self.config.seed)
        self._torch_device = torch.device("cpu")
        # Noise-pool bookkeeping used by ``sample_noise``.
        self._noise_pool: Array | None = None
        self._pool_ptr: int = 0
        self._pool_usage_count: int = 0
        self._pool_refresh_threshold: int = self.config.noise_pool_refresh_threshold
        # Called last so the hook can rely on everything initialised above.
        self._setup_problem()

    def _setup_problem(self) -> None:
        """Hook for subclass initialization; the default does nothing."""

    # ------------------------------------------------------------------
    # Noise samplers
    # ------------------------------------------------------------------
    def sample_noise_fresh(self, count: int | None = None) -> Array:
        """Draw fresh symmetric noise matrices E_i (slow path).

        This is the default implementation for matrix noise; subclasses can
        override it.

        Args:
            count: Number of matrices to draw. ``None`` means
                ``config.num_noise_samples``.

        Returns:
            A float64 array of shape ``(count, dimension, dimension)`` in
            which every matrix is symmetric.
        """
        num = count if count is not None else self.config.num_noise_samples
        dim = self.config.dimension
        # Drawn in float32 and widened afterwards, exactly as before, so the
        # random stream produced for a given seed is unchanged.
        raw = self.rng.standard_normal(size=(num, dim, dim), dtype=np.float32)
        # Symmetrise: E = (G + G^T) / 2.
        raw += np.transpose(raw, (0, 2, 1))
        raw *= 0.5
        return np.asarray(raw, dtype=np.float64)

    def sample_noise(self, count: int | None = None) -> Array:
        """Fetch noise from a cached pool (fast path).

        When ``noise_pool_refresh_threshold`` is ``<= 1`` the cache is
        bypassed and this is equivalent to ``sample_noise_fresh``. Otherwise
        batches are consecutive slices of a pool that is regenerated after
        ``noise_pool_refresh_threshold`` calls, or earlier if a request is
        larger than the current pool.

        Args:
            count: Number of samples to return. ``None`` means
                ``config.num_noise_samples``.

        Returns:
            An array with ``count`` samples along the first axis. On the
            cached path this is a *view* into the pool: do not modify it in
            place, or later batches will see the modification.

        Raises:
            ValueError: If ``count`` is negative.
        """
        if self._pool_refresh_threshold <= 1:
            return self.sample_noise_fresh(count=count)

        num = count if count is not None else self.config.num_noise_samples
        if num < 0:
            # The fresh path fails the same way (NumPy rejects negative sizes);
            # without this check the pool pointer would move backwards.
            raise ValueError(f"count must be non-negative, got {num}")

        pool = self._noise_pool
        needs_refresh = (
            pool is None
            or self._pool_usage_count >= self._pool_refresh_threshold
            # A pool sized for an earlier, smaller request cannot serve this
            # one; slicing it would silently return fewer than ``num`` samples.
            or num > len(pool)
        )
        if needs_refresh:
            # Cache size: at least 128, or 5x the requested batch to allow
            # some variation between consecutive batches.
            pool = self.sample_noise_fresh(count=max(128, num * 5))
            self._noise_pool = pool
            self._pool_ptr = 0
            self._pool_usage_count = 0

        start = self._pool_ptr
        stop = start + num
        if stop > len(pool):
            # Wrap around: restart from the beginning so the batch stays one
            # contiguous slice.
            start, stop = 0, num

        self._pool_ptr = stop
        self._pool_usage_count += 1
        return pool[start:stop]

    # ------------------------------------------------------------------
    # Problem-specific definitions to be supplied by subclasses
    # ------------------------------------------------------------------
    @abstractmethod
    def generate_random_guess(self) -> ToyExampleSol:
        """Return a starting point; how it is drawn is up to the subclass."""

    @abstractmethod
    def exact_solution(self) -> ToyExampleSol:
        """Return the reference solution as defined by the subclass."""

    @abstractmethod
    def _objective_torch(self, x_t: torch.Tensor, y_t: torch.Tensor, noise_samples_t: torch.Tensor) -> torch.Tensor:
        """Return scalar objective value using torch tensors."""

    @abstractmethod
    def _constraints_torch(self, x_t: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
        """Return constraint tensor c(x, y) with shape (num_constraints,)."""

    # ------------------------------------------------------------------
    # Public numpy-facing APIs (auto-diff under the hood)
    # ------------------------------------------------------------------
    def constraint_residual(self, x: Array, y: Array) -> Array:
        """Evaluate ``_constraints_torch`` at ``(x, y)`` and return it as numpy."""
        with torch.no_grad():
            x_t = torch.as_tensor(x, dtype=torch.float64, device=self._torch_device)
            y_t = torch.as_tensor(y, dtype=torch.float64, device=self._torch_device)
            con = self._constraints_torch(x_t, y_t)
        return con.detach().cpu().numpy()

    def constraint_gradient_x(self, x: Array, y: Array) -> Array:
        """Return the Jacobian of the constraints with respect to ``x``.

        The result has shape ``c.shape + x.shape``. Both Jacobians are
        computed internally; only the ``x`` part is returned.
        """
        jac_x, _ = self._autograd_constraint_grads(x, y)
        return jac_x

    def constraint_gradient_y(self, x: Array, y: Array) -> Array:
        """Return the Jacobian of the constraints with respect to ``y``.

        The result has shape ``c.shape + y.shape``. Both Jacobians are
        computed internally; only the ``y`` part is returned.
        """
        _, jac_y = self._autograd_constraint_grads(x, y)
        return jac_y

    def _project_onto_box(self, vector: Array, bounds_attr: str) -> Array:
        """Clip ``vector`` to the ``(low, high)`` pair stored in ``bounds_attr``.

        Returns ``vector`` itself (no copy) when the attribute is missing or
        ``None``.
        """
        bounds = getattr(self, bounds_attr, None)
        if bounds is None:
            return vector
        low, high = bounds
        return np.clip(vector, low, high)

    def project_x(self, vector: Array) -> Array:
        """Clip a vector to the problem-specific x box if provided."""
        return self._project_onto_box(vector, "x_bounds")

    def project_y(self, vector: Array) -> Array:
        """Clip a vector to the problem-specific y box if provided."""
        return self._project_onto_box(vector, "y_bounds")

    def gradient_x(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Return the objective gradient with respect to ``x``.

        Args:
            x: Current x variable.
            y: Current y variable.
            samples: Noise batch to evaluate on. ``None`` draws a batch via
                ``sample_noise``; a separate ``gradient_y`` call with ``None``
                draws its own batch.
        """
        grad_x, _ = self._autograd_objective_grads(x, y, samples=samples)
        return grad_x

    def gradient_y(self, x: Array, y: Array, *, samples: Array | None) -> Array:
        """Return the objective gradient with respect to ``y``.

        Args:
            x: Current x variable.
            y: Current y variable.
            samples: Noise batch to evaluate on. ``None`` draws a batch via
                ``sample_noise``; a separate ``gradient_x`` call with ``None``
                draws its own batch.
        """
        _, grad_y = self._autograd_objective_grads(x, y, samples=samples)
        return grad_y

    def stochastic_objective(self, x: Array, y: Array, noise_samples: Array) -> float:
        """Evaluate the objective on the given noise batch as a Python float."""
        with torch.no_grad():
            x_t = torch.as_tensor(x, dtype=torch.float64, device=self._torch_device)
            y_t = torch.as_tensor(y, dtype=torch.float64, device=self._torch_device)
            noise_t = torch.as_tensor(noise_samples, dtype=torch.float64, device=self._torch_device)
            val = self._objective_torch(x_t, y_t, noise_t)
        return float(val.detach().cpu().item())

    # ------------------------------------------------------------------
    # Internal helpers for auto-diff
    # ------------------------------------------------------------------
    def _autograd_objective_grads(self, x: Array, y: Array, *, samples: Array | None) -> tuple[Array, Array]:
        """Differentiate the objective with respect to ``x`` and ``y``.

        Args:
            x: Current x variable.
            y: Current y variable.
            samples: Noise batch; ``None`` draws one via ``sample_noise``.

        Returns:
            ``(grad_x, grad_y)`` as numpy arrays shaped like ``x`` and ``y``.
            A variable the objective does not depend on gets a zero gradient.
        """
        samples_np = self.sample_noise() if samples is None else samples
        # ``torch.tensor`` copies, so the caller's arrays are never tied to
        # the autograd graph.
        x_t = torch.tensor(x, dtype=torch.float64, requires_grad=True, device=self._torch_device)
        y_t = torch.tensor(y, dtype=torch.float64, requires_grad=True, device=self._torch_device)
        samples_t = torch.as_tensor(samples_np, dtype=torch.float64, device=self._torch_device)

        obj = self._objective_torch(x_t, y_t, samples_t)
        if obj.requires_grad:
            gx_t, gy_t = torch.autograd.grad(obj, (x_t, y_t), allow_unused=True)
        else:
            # The objective depends on neither variable. ``autograd.grad``
            # would raise here, so use the same zero fill as for one unused
            # input.
            gx_t = gy_t = None

        if gx_t is None:
            gx_t = torch.zeros_like(x_t)
        if gy_t is None:
            gy_t = torch.zeros_like(y_t)

        return gx_t.detach().cpu().numpy(), gy_t.detach().cpu().numpy()

    def _autograd_constraint_grads(self, x: Array, y: Array) -> tuple[Array, Array]:
        """Compute the constraint Jacobians with respect to ``x`` and ``y``.

        Returns:
            ``(jac_x, jac_y)`` as numpy arrays with shapes
            ``c.shape + x.shape`` and ``c.shape + y.shape``.
        """
        x_t = torch.tensor(x, dtype=torch.float64, requires_grad=True, device=self._torch_device)
        y_t = torch.tensor(y, dtype=torch.float64, requires_grad=True, device=self._torch_device)

        # Each closure fixes one argument so the Jacobian is taken with
        # respect to the other only.
        def _con_x(z):
            return self._constraints_torch(z, y_t)

        def _con_y(z):
            return self._constraints_torch(x_t, z)

        jac_cx = torch.autograd.functional.jacobian(_con_x, x_t, vectorize=True)
        jac_cy = torch.autograd.functional.jacobian(_con_y, y_t, vectorize=True)

        return jac_cx.detach().cpu().numpy(), jac_cy.detach().cpu().numpy()
