#!/usr/bin/python3
"""
Implements naive first-order gradient ascent.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
from pydantic import BaseModel
from typing import Any, Dict, Final, List, NamedTuple, Optional, Tuple

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import Pydantic2Rd, Rd2Pydantic


class GradAscentOptimizer(BaseOptimizer):
    """Naive first-order gradient ascent with finite-difference gradients.

    Each acquisition round, `fit()` selects the highest-scoring previously
    evaluated designs as starting points. `forward()` then moves each of them
    `num_steps_per_acq` steps of size `eta` along a central finite-difference
    estimate of the task objective's gradient (see `backward()`).
    """

    optimizer_name: str = "GradAscent"

    def __init__(
        self,
        task: BaseTask,
        ref: NamedTuple,
        batch_size: int,
        num_steps_per_acq: int = 4,
        eta: float = 0.01,
        seed: int = 2025,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            ref: the reference individual that we are optimizing for.
            batch_size: the number of new designs to return per sampling step.
            num_steps_per_acq: the number of iterative gradient steps to take
                per acquisition step. Default 4.
            eta: step size. Default 0.01.
            seed: random seed. Default 2025.
            kwargs: additional keyword arguments; accepted for interface
                compatibility and ignored.
        """
        del kwargs
        super().__init__(task, batch_size, seed=seed)
        self.ref: Final[NamedTuple] = ref
        self.num_steps_per_acq = num_steps_per_acq
        self.eta = eta
        # Starting points for the next `forward()` call. `fit()` sets them and
        # `forward()` clears them once they have been consumed.
        self.X: Optional[np.ndarray] = None

    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        **kwargs: Any
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.

        Starting from the designs selected by the most recent `fit()` call,
        takes `num_steps_per_acq` gradient ascent steps of size `eta`. The
        starting points are consumed, so `fit()` must be called again before
        the next `forward()`.
        Input:
            state: the current optimizer state (unused).
            knowledge: the prior knowledge to use for the optimization (unused).
        Returns:
            A batch of candidates to evaluate, with at most `batch_size`
            elements, and an (empty) dictionary of metadata.
        """
        del state, knowledge, kwargs
        assert self.X is not None, "fit() must be called before forward()."
        for _ in range(self.num_steps_per_acq):
            grad = self.backward(self.X)
            # Zero out non-finite gradient entries (NaN from e.g. inf - inf or
            # a NaN objective, or +/-inf), so one bad estimate does not push
            # the iterate to NaN/inf. Those coordinates simply do not move.
            grad = np.where(np.isfinite(grad), grad, 0.0)
            self.X = self.X + (grad * self.eta)
        candidates = Rd2Pydantic(self.task, self.X, continuous=True)
        self.X = None
        return candidates, {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Selects the starting points for the next `forward()` call.
        Input:
            X: a list of all prior evaluated designs.
            y: an array of N objective evaluations (any shape that flattens
                to N, e.g. N or N1), where N is the number of designs.
        Returns:
            None. Stores the (up to) `batch_size` highest-scoring designs, as
            a 2D real-valued array, in `self.X`.
        """
        Xval = Pydantic2Rd(self.task, X, continuous=True)
        # Flatten instead of squeezing: squeeze turns a single observation
        # into a 0-d array, which cannot be sliced below.
        scores = np.asarray(y, dtype=np.float64).reshape(-1)
        # np.argsort places NaN last, which would rank failed evaluations as
        # the best designs. Rank them lowest instead.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        order = np.argsort(scores)
        # Use an explicit start index so batch_size == 0 selects nothing
        # (order[-0:] would select every design).
        start = max(order.size - self.batch_size, 0)
        self.X = Xval[order[start:]]
        if self.X.ndim == 1:
            # A 1D encoding is treated as N designs with a single dimension.
            self.X = self.X.reshape(-1, 1)

    def backward(self, x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
        """
        Computes a central finite-difference approximation of the gradient of
        the task objective at a batch of points:

            grad[:, j] = (f(x + eps * e_j) - f(x - eps * e_j)) / (2 * eps)

        Each perturbed batch is decoded with `Rd2Pydantic`, combined with the
        reference individual via `task.extend`, and scored by calling the
        task. This makes 2 * D task calls of N designs each.
        Input:
            x: an array of shape ND of the designs to evaluate, where N is
                the number of designs and D is the number of design dimensions.
            eps: perturbation size for gradient estimation. Default 1e-6.
        Returns:
            An array of shape ND of the estimated gradient at the points.
        """
        assert x.ndim == 2, "backward() expects a 2D array of shape ND."
        x = x.astype(np.float64)
        grad = np.zeros_like(x)

        for j in range(x.shape[1]):
            xp, xm = x.copy(), x.copy()
            xp[:, j] += eps
            xm[:, j] -= eps

            xplus = self.task.extend(
                Rd2Pydantic(self.task, xp, continuous=True), self.ref
            )
            xminus = self.task.extend(
                Rd2Pydantic(self.task, xm, continuous=True), self.ref
            )

            # Flatten so objectives returned as an N1 column still fit into
            # the length-N slice grad[:, j].
            y_plus = np.asarray(self.task(xplus)).reshape(-1)
            y_minus = np.asarray(self.task(xminus)).reshape(-1)

            grad[:, j] = (y_plus - y_minus) / (2.0 * eps)

        return grad
