"""
Quartic function implementation.
"""

import numpy as np
from .base import BlackBoxFunction

class QuarticFunction(BlackBoxFunction):
    """Quartic function implementation.

    This class implements a separable quartic function of the form

        f(x) = a * sum_i x_i**4

    where ``a`` is a scalar coefficient. Its gradient is computed
    element-wise as ``4 * a * x**3``.

    Inputs with boolean or integer dtype are promoted to float64 before
    evaluation: NumPy integer arithmetic wraps around silently on overflow,
    and ``x**4`` leaves the int64 range once ``|x_i|`` exceeds roughly
    55,000, which would otherwise yield wrong values with no error.

    Attributes:
        input_dim (int): The dimension of the input space.
        a (float): The coefficient for the quartic term.
    """

    def __init__(self, input_dim=10, a=1.0):
        """Initialize a quartic function.

        Args:
            input_dim (int, optional): The dimension of the input space. Defaults to 10.
            a (float, optional): The coefficient for the quartic term. Defaults to 1.0.
        """
        super().__init__(input_dim)
        self.a = a

    def set_coefficient(self, a):
        """Set the coefficient for the quartic term.

        Args:
            a (float): The coefficient for the quartic term.
        """
        self.a = a

    @staticmethod
    def _as_numeric_array(x):
        """Return ``x`` as an ndarray, promoting bool/integer dtypes to float64.

        Float and complex arrays are returned as-is (``np.asarray`` does not
        copy an existing ndarray of the same type), so float32 inputs and
        complex-valued inputs keep their dtype.

        Args:
            x (array_like): Input vector (ndarray, list, or scalar).

        Returns:
            numpy.ndarray: ``x`` as an array suitable for power operations.
        """
        arr = np.asarray(x)
        # dtype kinds 'b' (bool), 'i' (signed int), 'u' (unsigned int) would
        # overflow silently under x**3 / x**4, so evaluate in float64 instead.
        if arr.dtype.kind in "biu":
            arr = arr.astype(np.float64)
        return arr

    def _f(self, x):
        """Compute the function value ``a * sum(x_i**4)``.

        Args:
            x (numpy.ndarray): Input vector. Lists and integer arrays are
                accepted and converted to float arrays first.

        Returns:
            float: Function value at x.
        """
        x = self._as_numeric_array(x)
        return self.a * np.sum(x**4)

    def _grad(self, x):
        """Compute the gradient ``4 * a * x**3`` (element-wise).

        Args:
            x (numpy.ndarray): Input vector. Lists and integer arrays are
                accepted and converted to float arrays first.

        Returns:
            numpy.ndarray: Gradient at x, with the same shape as x.
        """
        x = self._as_numeric_array(x)
        return 4 * self.a * x**3