"""
Quadratic function implementation.
"""

import numpy as np
from .base import BlackBoxFunction

class QuadraticFunction(BlackBoxFunction):
    """Quadratic function of the form f(x) = x^T W x.

    ``W`` is a square weight matrix drawn at construction time, with entries
    sampled from the uniform distribution on [-1, 1). Because entries may be
    negative, ``W`` is in general indefinite, so ``f`` need not be convex or
    bounded below. ``W`` can be replaced afterwards via ``set_weights``.

    Attributes:
        input_dim (int): The dimension of the input space.
        weights (numpy.ndarray): The square weight matrix ``W``.
    """

    def __init__(self, input_dim=10, diagonal=True, *, rng=None):
        """Initialize a quadratic function with a random weight matrix.

        Args:
            input_dim (int, optional): The dimension of the input space.
                Defaults to 10.
            diagonal (bool, optional): If True, ``W`` is diagonal and its
                ``input_dim`` diagonal entries are drawn from U(-1, 1).
                Otherwise every entry of the dense ``(input_dim, input_dim)``
                matrix is drawn from U(-1, 1), so ``W`` is generally not
                symmetric. Defaults to True.
            rng (numpy.random.Generator | numpy.random.RandomState, optional):
                Source of randomness for reproducible weights. Any object
                with a numpy-style ``uniform(low, high, size)`` method works.
                Defaults to None, which uses the global ``numpy.random``
                state.
        """
        super().__init__(input_dim)
        if rng is None:
            # The global numpy.random module exposes the same uniform() API.
            rng = np.random
        if diagonal:
            self.weights = np.diag(rng.uniform(-1, 1, input_dim))
        else:
            self.weights = rng.uniform(-1, 1, (input_dim, input_dim))

    def set_weights(self, weights):
        """Replace the weight matrix ``W``.

        Args:
            weights (array_like): Square 2-D weight matrix. Array-likes such
                as nested lists are converted with ``numpy.asanyarray``. An
                existing ndarray (or ndarray subclass) is stored as-is,
                without copying.

        Raises:
            ValueError: If ``weights`` is not a square 2-D matrix.
        """
        weights = np.asanyarray(weights)
        # Validate eagerly. Otherwise a bad W only fails later inside
        # _f/_grad, far from the call that introduced it.
        if weights.ndim != 2 or weights.shape[0] != weights.shape[1]:
            raise ValueError(
                f"weights must be a square 2-D matrix, got shape {weights.shape}"
            )
        self.weights = weights

    def _f(self, x):
        """Compute the function value f(x) = x^T W x.

        Args:
            x (numpy.ndarray): Input vector.

        Returns:
            The quadratic form. For a 1-D ``x`` this is a numpy scalar. For a
            column vector of shape ``(n, 1)``, the matrix products yield a
            ``(1, 1)`` array.
        """
        return x.T @ self.weights @ x

    def _grad(self, x):
        """Compute the gradient of f, (W + W^T) x.

        The symmetrised matrix ``W + W^T`` is needed because ``W`` need not
        be symmetric (for example, the dense random initialisation).

        Args:
            x (numpy.ndarray): Input vector, either 1-D or a column vector.

        Returns:
            numpy.ndarray: The gradient as a 1-D array.
        """
        # For symmetric S, x.T @ S equals (S @ x).T. flatten() makes the
        # result 1-D whether x was passed as 1-D or as a column vector.
        return (x.T @ (self.weights + self.weights.T)).flatten()