"""
Logistic regression loss function implementation.
"""

import numpy as np
from .base import BlackBoxFunction

class LogisticLossFunction(BlackBoxFunction):
    """Logistic regression loss function implementation.

    This class implements a logistic regression loss function of the form:
    f(w) = (1/n) * sum_i log(1 + exp(-y_i * w^T * x_i))

    where:
    - w is the weight vector (parameter to optimize)
    - x_i are the feature vectors
    - y_i are the labels (-1 or 1)
    - n is the number of samples

    Both the loss and its gradient are evaluated through ``np.logaddexp`` so
    that large-magnitude margins do not overflow ``exp`` into ``inf``.

    Attributes:
        input_dim (int): The dimension of the input space (feature dimension).
        X (numpy.ndarray): The feature matrix of shape (n_samples, input_dim).
        y (numpy.ndarray): The label vector of shape (n_samples,).
    """

    def __init__(self, input_dim=10, n_samples=100, random_state=None):
        """Initialize a logistic regression loss function with synthetic data.

        Features are drawn from a standard normal distribution. A hidden
        weight vector (also standard normal) defines
        P(y = +1 | x) = 1 / (1 + exp(-x^T w_true)); each label is sampled
        from that probability and mapped to {-1, +1}.

        Args:
            input_dim (int, optional): The dimension of the input space. Defaults to 10.
            n_samples (int, optional): Number of samples to generate. Defaults to 100.
            random_state (int, optional): Random seed for reproducibility. Defaults to None.
                A non-None seed is applied to NumPy's *global* random state
                via ``np.random.seed``, so it also affects later ``np.random``
                calls made elsewhere in the program.
        """
        super().__init__(input_dim)

        # NOTE(review): this reseeds the global RNG rather than using a private
        # generator; kept as-is because callers may rely on it for end-to-end
        # reproducibility of experiments.
        if random_state is not None:
            np.random.seed(random_state)

        # The draw order (features, hidden weights, label noise) determines the
        # data produced for a given seed -- do not reorder.
        self.X = np.random.randn(n_samples, input_dim)
        true_weights = np.random.randn(input_dim)

        logits = self.X @ true_weights
        # For very negative logits exp() overflows to inf and the probability
        # correctly becomes 0; silence the (harmless) overflow warning.
        with np.errstate(over="ignore"):
            probs = 1 / (1 + np.exp(-logits))
        # Bernoulli draw with P(True) = probs, then map {False, True} -> {-1, +1}.
        self.y = 2 * (np.random.random(n_samples) < probs) - 1

    def set_data(self, X, y):
        """Set custom data for the loss function.

        The arrays are stored as given (``np.asarray`` does not copy an
        existing ndarray), so later in-place changes by the caller are seen
        by this object.

        Args:
            X (numpy.ndarray): The feature matrix of shape (n_samples, input_dim).
            y (numpy.ndarray): The label vector of shape (n_samples,).

        Raises:
            ValueError: If ``X`` is not 2-D, its feature dimension differs from
                ``input_dim``, ``y`` is not 1-D, or the sample counts differ.
        """
        X = np.asarray(X)
        y = np.asarray(y)

        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D (n_samples, input_dim), got shape {X.shape}")
        if X.shape[1] != self.input_dim:
            raise ValueError(f"Feature dimension mismatch: {X.shape[1]} != {self.input_dim}")
        # A column vector y of shape (n, 1) would silently broadcast the
        # margins to (n, n), so require a flat label vector.
        if y.ndim != 1:
            raise ValueError(f"y must be 1-D (n_samples,), got shape {y.shape}")
        if X.shape[0] != y.shape[0]:
            raise ValueError(f"Sample count mismatch: {X.shape[0]} != {y.shape[0]}")

        self.X = X
        self.y = y

    def _margins(self, w):
        """Return the per-sample margins y_i * w^T x_i (shape (n_samples,) for a 1-D w)."""
        return self.y * (self.X @ w)

    def _f(self, w):
        """Compute the logistic loss function value.

        Args:
            w (numpy.ndarray): Weight vector.

        Returns:
            float: Function value at w (mean of the per-sample losses).
        """
        margins = self._margins(w)

        # log(1 + exp(-m)) == logaddexp(0, -m); the latter neither overflows
        # for very negative m nor loses precision when exp(-m) is tiny.
        losses = np.logaddexp(0.0, -margins)

        return np.mean(losses)

    def _grad(self, w):
        """Compute the gradient of the logistic loss.

        Uses grad = -(1/n) * sum_i y_i * x_i * sigmoid(-y_i * w^T x_i).

        Args:
            w (numpy.ndarray): Weight vector.

        Returns:
            numpy.ndarray: Gradient at w, shape (input_dim,).
        """
        margins = self._margins(w)

        # sigmoid(-m) = 1 / (1 + exp(m)) = exp(-log(1 + exp(m))); going through
        # logaddexp avoids overflow warnings for large positive m.
        sigmoid_neg = np.exp(-np.logaddexp(0.0, margins))

        # X^T v / n equals averaging the rows of X scaled by v, without
        # materialising an (n_samples, input_dim) temporary.
        return -(self.X.T @ (self.y * sigmoid_neg)) / self.X.shape[0]