import numpy as np
import math

class TargetFunction:
    """Common interface for the target functions defined in this module.

    Concrete subclasses override ``compute_function``; this base class only
    declares the call signature.
    """

    def compute_function(self, X, a, b):
        """Compute the target value for inputs ``X`` using coefficients ``a`` and ``b``.

        Subclasses decide how (or whether) ``a`` and ``b`` are used.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

class LinearFunction(TargetFunction):
    """Noisy linear combination of the inputs with input-count-scaled weights."""

    def compute_function(self, X: np.ndarray, a: np.ndarray, b: float) -> float:
        """Return ``np.dot(X, w)`` with ``w = (a + noise + b) / sqrt(X.shape[0])``.

        Args:
            X: Input values (array-like). ``X.shape[0]`` sets the weight scaling.
            a: Per-input base weights (array-like). Noise is drawn with the
                same shape as ``a``, so its shape must be compatible with the
                dot product against ``X``.
            b: Offset added to every weight (broadcast against ``a``).

        Returns:
            ``np.dot(X, w)``; a scalar when ``X`` and ``a`` are both 1-D.

        Note:
            Fresh Gaussian noise (mean 0, std 0.1) is drawn from NumPy's global
            random state on every call, so repeated calls return different
            values unless the caller seeds ``np.random``.
        """
        # Accept any array-like (e.g. lists); ndarrays pass through unchanged.
        # ``a`` must have a ``.shape`` for the noise draw, so a bare Python
        # sequence would otherwise fail here.
        inputs = np.asarray(X)
        base_weights = np.asarray(a)
        noise = np.random.normal(0, 0.1, size=base_weights.shape)
        # Shrink the weights as the number of inputs (parents) grows.
        scaled_weights = (base_weights + noise + b) / np.sqrt(inputs.shape[0])
        return np.dot(inputs, scaled_weights)

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Linear Target Function"
    
class PolynomialFunction(TargetFunction):
    """Scaled sum of ``(a * x + b) ** degree`` along the leading axis of ``X``."""

    def compute_function(self, X: np.ndarray, a: float, b: float, degree: int = 2) -> float:
        """Return ``sum((a * X + b) ** degree, axis=0) / X.shape[0] ** degree``.

        Args:
            X: Input array; terms are summed along axis 0, and ``X.shape[0]``
                is the term count ``n`` used for scaling.
            a: Multiplicative coefficient applied to ``X``.
            b: Additive offset applied after scaling by ``a``.
            degree: Power each term is raised to (default 2).

        Returns:
            The scaled sum. If ``X`` is empty along axis 0, the (zero) sum is
            returned unscaled instead of dividing by ``0 ** degree``.
        """
        n_terms = X.shape[0]
        total = np.sum((a * X + b) ** degree, axis=0)
        if n_terms == 0:
            # Empty sum is already zero; scaling would raise ZeroDivisionError.
            return total
        scale_factor = 1.0 / (n_terms ** degree)
        return total * scale_factor

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Polynomial Target Function"
    
class SineFunction(TargetFunction):
    """Sum of element-wise sines of the inputs, plus Gaussian noise."""

    def compute_function(self, X: np.ndarray, a=None, b=None) -> float:
        """Return ``sum(sin(X))`` plus one Gaussian noise sample (mean 0, std 0.1).

        ``a`` and ``b`` are accepted only to match the ``TargetFunction``
        signature and are ignored. The noise comes from NumPy's global
        random state.
        """
        sine_total = np.sum(np.sin(X))
        jitter = np.random.normal(0, 0.1)
        return sine_total + jitter

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Sine Target Function"
    
class ThresholdFunction(TargetFunction):
    """Noisy step function of the summed inputs."""

    def compute_function(self, X: np.ndarray, a=None, b=None) -> float:
        """Return 1.0 if ``sum(X) > 0`` else 0.0, plus Gaussian noise (mean 0, std 0.1).

        ``a`` and ``b`` are accepted only to match the ``TargetFunction``
        signature and are ignored. The noise comes from NumPy's global
        random state.
        """
        is_above_zero = np.sum(X) > 0
        step_value = is_above_zero.astype(float)
        return step_value + np.random.normal(0, 0.1)

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Threshold Target Function"
    
class RadialBasisFunction(TargetFunction):
    """Gaussian radial basis function of the squared Euclidean norm of ``X``."""

    def compute_function(self, X: np.ndarray, a=None, b=None, *, sigma: float = 1.0) -> float:
        """Return ``exp(-sum(X ** 2) / (2 * sigma ** 2))``.

        Args:
            X: Input array; all elements contribute to the squared norm.
            a: Ignored; accepted only to match the ``TargetFunction`` signature.
            b: Ignored; accepted only to match the ``TargetFunction`` signature.
            sigma: Bandwidth of the Gaussian kernel (keyword-only, default 1.0).

        Returns:
            A value in ``(0, 1]``; exactly 1.0 when every element of ``X`` is 0.

        Raises:
            ValueError: if ``sigma`` is not strictly positive (including NaN),
                which would otherwise produce inf/nan via division by zero.
        """
        # ``not sigma > 0`` also rejects NaN, which ``sigma <= 0`` would let through.
        if not sigma > 0:
            raise ValueError(f"sigma must be strictly positive, got {sigma!r}")
        norm_sq = np.sum(X ** 2)
        return np.exp(-norm_sq / (2 * sigma ** 2))

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Radial Basis Target Function"
    
class CheckerboardFunction(TargetFunction):
    """Parity of the summed integer cell coordinates of ``X``."""

    def compute_function(self, X: np.ndarray, a=None, b=None) -> float:
        """Return ``sum(floor(X)) % 2``, i.e. 0.0 or 1.0 for finite inputs.

        ``a`` and ``b`` are accepted only to match the ``TargetFunction``
        signature and are ignored. NumPy's ``%`` takes the sign of the
        divisor, so negative cell sums still map to 0.0 or 1.0.
        """
        cell_sum = np.sum(np.floor(X))
        return cell_sum % 2

    def __str__(self) -> str:
        """Return the human-readable name of this target function."""
        return "Checkerboard Target Function"