from ucimlrepo import fetch_ucirepo
import numpy as np
import math
import operator
from collections import OrderedDict

# Synthetic benchmark functions
def koza_1(x):
    """Evaluate the Koza-1 quartic polynomial x^4 + x^3 + x^2 + x at ``x``."""
    quartic = x**4
    cubic = x**3
    quadratic = x**2
    # Terms are summed from the highest power down.
    return quartic + cubic + quadratic + x

def keizjer_7(x):
    """Keijzer-7 benchmark: f(x) = ln(x).

    ``x`` must be strictly positive; ``math.log`` raises ``ValueError``
    otherwise.
    """
    natural_log = math.log(x)
    return natural_log

def korns_11(x):
    """Korns-11 benchmark: f(x) = 6.87 + 11 * cos(7.23 * x^3)."""
    cosine_argument = 7.23 * x**3
    oscillation = 11 * math.cos(cosine_argument)
    return 6.87 + oscillation

def nguyen_7(x):
    """Evaluate the Nguyen-7 target ln(x + 1) + ln(x^2 + 1).

    ``math.log`` raises ``ValueError`` when ``x <= -1``, since the first
    logarithm's argument is then non-positive.
    """
    linear_term = math.log(x + 1)
    quadratic_term = math.log(x**2 + 1)
    return linear_term + quadratic_term

def nguyen_11(x, y):
    """Nguyen-11 benchmark: f(x, y) = x^y."""
    return pow(x, y)

def nguyen_5(x):
    """Nguyen-5 benchmark: f(x) = sin(x^2) * cos(x) - 1."""
    sine_of_square = math.sin(x**2)
    cosine = math.cos(x)
    return sine_of_square * cosine - 1

def keizer_11(x, y):
    """Keijzer-11 benchmark: f(x, y) = x*y + sin((x - 1) * (y - 1))."""
    product = x * y
    shifted_product = (x - 1) * (y - 1)
    return product + math.sin(shifted_product)

def vladislavleva_5(x, y, z):
    return 30*((x-1)*(z-1))/(y**2*(x-10))

def keijzer_5(x, y, z):
    """Keijzer-5 benchmark: f(x, y, z) = 30*x*z / ((x - 10) * y^2).

    The denominator vanishes when ``y == 0`` or ``x == 10``.
    """
    numerator = 30 * x * z
    denominator = (x - 10) * y**2
    return numerator / denominator

def vladislavleva_1(x, y):
    """Vladislavleva-1 benchmark.

    f(x, y) = exp(-(x - 1)^2) / (1.2 + (y - 2.5)^2)
    """
    squared_x_offset = (x - 1) ** 2
    squared_y_offset = (y - 2.5) ** 2
    bump = math.exp(-squared_x_offset)
    return bump / (1.2 + squared_y_offset)

# Input ranges for the synthetic benchmarks, keyed by benchmark name.
# Each value is a list with one (low, high) pair per input variable, in the
# positional order of the benchmark function's arguments; the list length is
# what get_dataset() uses to decide how many variables to sample.
# NOTE(review): "salustowicz" has a range here but no function is defined for
# it in this part of the file -- confirm whether it lives elsewhere.
synthetic_ranges = {
    "koza_1": [(-1.3, 1.3)],
    "korns_11": [(-1, 1)],
    # nguyen_7 evaluates ln(x + 1), which is undefined for x <= -1, so the
    # interval must stay to the right of -1. [0, 2] is the interval this
    # benchmark is customarily sampled on.
    "nguyen_7": [(0, 2)],
    "salustowicz": [(-1.3, 1.3)],
    # Lower bound is slightly above 0 because keizjer_7 evaluates ln(x).
    "keizjer_7": [(0.0001, 4)],
    "nguyen_11": [(1, 5), (1, 5)],
    "keizer_11": [(-1, 1), (-1, 1)],
    "vladislavleva_5": [(0, 4), (1, 3), (0, 4)],
    "keijzer_5": [(-2, 2), (1, 3), (-2, 2)],
    "vladislavleva_1": [(-1, 1), (-1, 1)],
    "nguyen_5": [(-1.6, 1.6)],
}

# Registry of the synthetic benchmarks exposed through get_dataset(), as
# name -> callable. Insertion order is the order get_all_dataset_names()
# reports them in.
synthetic_functions = OrderedDict([
    ("korns_11", korns_11),
    ("keizjer_7", keizjer_7),
    ("vladislavleva_5", vladislavleva_5),
    ("keijzer_5", keijzer_5),
    ("vladislavleva_1", vladislavleva_1),
    ("nguyen_11", nguyen_11),
    ("keizer_11", keizer_11),
    ("nguyen_5", nguyen_5),
])

# Real dataset functions
def airfoil_self_noise():
    """Fetch the Airfoil Self-Noise dataset (UCI repository id 291).

    Returns:
        tuple: (X, y) -- the ``data.features`` and ``data.targets`` objects
        of the result returned by ``fetch_ucirepo``.
    """
    dataset = fetch_ucirepo(id=291)
    return dataset.data.features, dataset.data.targets

def concrete_compressive_strength():
    """Fetch the Concrete Compressive Strength dataset (UCI repository id 165).

    Returns:
        tuple: (X, y) -- the ``data.features`` and ``data.targets`` objects
        of the result returned by ``fetch_ucirepo``.
    """
    dataset = fetch_ucirepo(id=165)
    return dataset.data.features, dataset.data.targets

def combined_cycle_power_plant():
    """Fetch the Combined Cycle Power Plant dataset (UCI repository id 294).

    Returns:
        tuple: (X, y) -- the ``data.features`` and ``data.targets`` objects
        of the result returned by ``fetch_ucirepo``.
    """
    dataset = fetch_ucirepo(id=294)
    return dataset.data.features, dataset.data.targets

def energy_efficiency_heating_load():
    """Fetch the Energy Efficiency dataset (UCI id 242) with the heating load target.

    Returns:
        tuple: (X, y) -- the dataset's ``data.features`` and the ``'Y1'``
        column of its ``data.targets``.
    """
    dataset = fetch_ucirepo(id=242)
    features = dataset.data.features
    # Only the 'Y1' target column is returned by this loader.
    heating_load = dataset.data.targets['Y1']
    return features, heating_load

def energy_efficiency_cooling_load():
    """Fetch the Energy Efficiency dataset (UCI id 242) with the cooling load target.

    Returns:
        tuple: (X, y) -- the dataset's ``data.features`` and the ``'Y2'``
        column of its ``data.targets``.
    """
    dataset = fetch_ucirepo(id=242)
    features = dataset.data.features
    # Only the 'Y2' target column is returned by this loader.
    cooling_load = dataset.data.targets['Y2']
    return features, cooling_load

# Registry of the real datasets exposed through get_dataset(), as
# name -> zero-argument loader returning (X, y). Insertion order is the order
# get_all_dataset_names() reports them in.
real_datasets = OrderedDict([
    ("airfoil_self_noise", airfoil_self_noise),
    ("concrete_compressive_strength", concrete_compressive_strength),
    ("combined_cycle_power_plant", combined_cycle_power_plant),
    ("energy_efficiency_heating_load", energy_efficiency_heating_load),
    ("energy_efficiency_cooling_load", energy_efficiency_cooling_load),
])

def get_dataset(name, num_points=1000, random_seed=42):
    """
    Get either a synthetic or real dataset by name.

    The name is looked up in ``synthetic_functions`` first and in
    ``real_datasets`` second.

    Args:
        name: Name of the dataset (synthetic or real)
        num_points: Number of points to generate for synthetic datasets.
            Not used for real datasets.
        random_seed: Seed used to sample the inputs of multi-variable
            synthetic benchmarks. Single-variable benchmarks are evaluated on
            an evenly spaced grid and real datasets are returned as loaded,
            so the seed does not affect them.

    Returns:
        tuple: (X, y) where X is the input features and y is the target values.
        For a single-variable synthetic benchmark X is 1-D with shape
        (num_points,); for a multi-variable one it has shape
        (num_points, num_variables). For real datasets, X and y are the
        loader's outputs converted with ``.to_numpy()``, with y squeezed.

    Raises:
        ValueError: If ``name`` is neither a synthetic nor a real dataset.
    """
    if name in synthetic_functions:
        target_func = synthetic_functions[name]
        variable_ranges = synthetic_ranges[name]
        num_variables = len(variable_ranges)

        if num_variables == 1:
            # Single input: deterministic, evenly spaced grid over the range.
            low, high = variable_ranges[0]
            X = np.linspace(low, high, num_points)
            y = np.array([target_func(x) for x in X])
        else:
            # Use a private generator rather than np.random.seed() so that
            # calling this function does not reset the process-wide NumPy
            # random state. RandomState(seed) yields the same stream the
            # seeded global generator would.
            rng = np.random.RandomState(random_seed)
            X = np.zeros((num_points, num_variables))
            for i, (low, high) in enumerate(variable_ranges):
                X[:, i] = rng.uniform(low, high, num_points)
            # Benchmarks take scalar arguments, so evaluate row by row.
            y = np.array([target_func(*row) for row in X])

        return X, y

    if name in real_datasets:
        X, y = real_datasets[name]()
        # Values are returned as loaded; no normalization is applied.
        return X.to_numpy(), y.to_numpy().squeeze()

    raise ValueError(f"Unknown dataset name: {name}")

def get_all_dataset_names(function_type):
    """List the registered dataset names, optionally filtered by kind.

    Args:
        function_type: ``"synthetic"`` to list only synthetic benchmarks,
            ``"real"`` to list only real datasets. Any other value lists the
            synthetic names followed by the real ones.

    Returns:
        list: Dataset names in registry order.
    """
    if function_type == "synthetic":
        return list(synthetic_functions)
    if function_type == "real":
        return list(real_datasets)
    return [*synthetic_functions, *real_datasets]

if __name__ == "__main__":
    # Smoke check: generate a small multi-variable synthetic benchmark.
    synthetic_inputs, synthetic_targets = get_dataset(
        "vladislavleva_5", num_points=100
    )
    print(
        "Synthetic dataset shape:",
        synthetic_inputs.shape,
        synthetic_targets.shape,
    )

    # Smoke check: load one real dataset through its registered loader.
    real_inputs, real_targets = get_dataset("airfoil_self_noise")
    print("Real dataset shape:", real_inputs.shape, real_targets.shape)
    print(type(real_inputs))

    # List the registered real datasets (synthetic ones are not listed here).
    print("\nAvailable datasets:")
    for name in get_all_dataset_names(function_type="real"):
        print(f"- {name}")