# benchmark_problems.py
# Comprehensive benchmark problem suite for multi-objective optimization
# Includes ZDT, DTLZ, and real-world problems with true Pareto front generation

import numpy as np
from scipy.stats import qmc
import matplotlib.pyplot as plt

def get_problem_function(problem_name, M=2):
    """Look up a benchmark objective function by name.

    Parameters
    ----------
    problem_name : str
        One of 'ZDT1', 'ZDT2', 'ZDT3', 'ZDT4', 'ZDT6', 'DTLZ1' ... 'DTLZ7'.
        Any other name falls back to ``zdt1``.
    M : int
        Number of objectives; bound into the returned callable for DTLZ
        problems and ignored for ZDT problems.

    Returns
    -------
    callable
        A function taking only the decision matrix ``x``.
    """
    zdt_table = {
        'ZDT1': zdt1,
        'ZDT2': zdt2,
        'ZDT3': zdt3,
        'ZDT4': zdt4,
        'ZDT6': zdt6,
    }
    dtlz_table = {
        'DTLZ1': dtlz1,
        'DTLZ2': dtlz2,
        'DTLZ3': dtlz3,
        'DTLZ4': dtlz4,
        'DTLZ5': dtlz5,
        'DTLZ6': dtlz6,
        'DTLZ7': dtlz7,
    }
    if problem_name in zdt_table:
        return zdt_table[problem_name]
    if problem_name in dtlz_table:
        base = dtlz_table[problem_name]
        # Bind M now so callers only have to pass the decision matrix.
        return lambda x: base(x, M)
    return zdt1

def get_true_pareto_front(problem_name, M=2, n_points=100):
    """Return a reference (true) Pareto front for a benchmark problem.

    Parameters
    ----------
    problem_name : str
        Benchmark name ('ZDT1', ..., 'DTLZ7'). Unknown names fall back to the
        ZDT1 front.
    M : int
        Number of objectives (used by the DTLZ fronts only).
    n_points : int
        Requested number of reference points.

    Returns
    -------
    ndarray
        One objective vector per row.

    Notes
    -----
    DTLZ1 and DTLZ2-4 share ``true_pareto_front_dtlz124`` as a point sampler,
    but their fronts differ. DTLZ1 lies on the hyperplane sum(f) = 0.5 and
    DTLZ2-4 on the unit sphere sum(f**2) = 1. The sampled points are therefore
    projected radially onto the correct surface here. The projection is a
    no-op for points already on it.
    """
    if problem_name in ('ZDT1', 'ZDT4'):
        # ZDT4 shares ZDT1's front (g reaches its minimum 1 on the front).
        return true_pareto_front_zdt1(n_points)
    if problem_name == 'ZDT2':
        return true_pareto_front_zdt2(n_points)
    if problem_name == 'ZDT3':
        return true_pareto_front_zdt3(n_points)
    if problem_name == 'ZDT6':
        return true_pareto_front_zdt6(n_points)
    if problem_name in ('DTLZ1', 'DTLZ2', 'DTLZ3', 'DTLZ4'):
        base = np.asarray(true_pareto_front_dtlz124(M, n_points), dtype=float)
        if problem_name == 'DTLZ1':
            # Linear front: scale each point so its coordinates sum to 0.5.
            return base * (0.5 / base.sum(axis=1, keepdims=True))
        # Spherical front: scale each point to unit Euclidean norm.
        return base / np.linalg.norm(base, axis=1, keepdims=True)
    if problem_name in ('DTLZ5', 'DTLZ6'):
        # DTLZ6 has the same front as DTLZ5 (both reach g = 0).
        return true_pareto_front_dtlz5(M, n_points)
    if problem_name == 'DTLZ7':
        return true_pareto_front_dtlz7(M, n_points)

    # Default fallback
    return true_pareto_front_zdt1(n_points)

# ====================== ZDT Problems ======================

def zdt1(x):
    """ZDT1 benchmark (two objectives).

    f1 = x1,  g = 1 + 9 * sum(x2..xD) / (D - 1),  f2 = g * (1 - sqrt(f1 / g)).

    Parameters
    ----------
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    X = np.atleast_2d(x)
    _, n_vars = X.shape
    first = X[:, 0]
    tail_total = np.sum(X[:, 1:], axis=1)
    g = 1 + 9 * tail_total / (n_vars - 1)
    second = g * (1 - np.sqrt(first / g))
    return np.column_stack([first, second])

def zdt2(x):
    """ZDT2 benchmark (two objectives).

    f1 = x1,  g = 1 + 9 * sum(x2..xD) / (D - 1),  f2 = g * (1 - (f1 / g)**2).

    Parameters
    ----------
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    X = np.atleast_2d(x)
    _, n_vars = X.shape
    first = X[:, 0]
    g = 1 + 9 * np.sum(X[:, 1:], axis=1) / (n_vars - 1)
    ratio = first / g
    return np.column_stack([first, g * (1 - ratio ** 2)])

def zdt3(x):
    """ZDT3 benchmark (two objectives, disconnected front).

    f1 = x1,  g = 1 + 9 * sum(x2..xD) / (D - 1),
    f2 = g * (1 - sqrt(f1 / g) - (f1 / g) * sin(10 * pi * f1)).

    Parameters
    ----------
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    X = np.atleast_2d(x)
    _, n_vars = X.shape
    first = X[:, 0]
    g = 1 + 9 * np.sum(X[:, 1:], axis=1) / (n_vars - 1)
    ratio = first / g
    shape = 1 - np.sqrt(ratio) - ratio * np.sin(10 * np.pi * first)
    return np.column_stack([first, g * shape])

def zdt4(x):
    """ZDT4 benchmark (two objectives, Rastrigin-type g).

    Inputs are taken in the unit cube; x2..xD are mapped onto [-5, 5] here:
        y = 10 * x_{2..D} - 5
        g = 1 + 10 (D - 1) + sum(y**2 - 10 cos(4 pi y))
        f1 = x1,  f2 = g * (1 - sqrt(f1 / g))

    Parameters
    ----------
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    X = np.atleast_2d(x)
    _, n_vars = X.shape
    # Work on a copy so the caller's array is not modified by the remapping.
    mapped = X.copy()
    mapped[:, 1:] = 10 * mapped[:, 1:] - 5
    y = mapped[:, 1:]
    rastrigin = np.sum(y ** 2 - 10 * np.cos(4 * np.pi * y), axis=1)
    g = 1 + 10 * (n_vars - 1) + rastrigin
    f1 = X[:, 0]
    f2 = g * (1 - np.sqrt(f1 / g))
    return np.column_stack([f1, f2])

def zdt6(x):
    """ZDT6 benchmark (two objectives).

    f1 = 1 - exp(-4 x1) * sin(6 pi x1)**6
    g  = 1 + 9 * (sum(x2..xD) / (D - 1)) ** 0.25
    f2 = g * (1 - (f1 / g)**2)

    Parameters
    ----------
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    X = np.atleast_2d(x)
    _, n_vars = X.shape
    lead = X[:, 0]
    f1 = 1 - np.exp(-4 * lead) * np.sin(6 * np.pi * lead) ** 6
    mean_tail = np.sum(X[:, 1:], axis=1) / (n_vars - 1)
    g = 1 + 9 * mean_tail ** 0.25
    f2 = g * (1 - (f1 / g) ** 2)
    return np.column_stack([f1, f2])

# ====================== DTLZ Problems ======================

def dtlz1(x, M=2):
    """DTLZ1: linear front.

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D); a 1-D input is treated as one row. Columns
        M-1 onward are the k = D - M + 1 distance variables.
    M : int
        Number of objectives.

    Returns
    -------
    ndarray of shape (N, M)
        With g = 100 (k + sum((x_j - 0.5)^2 - cos(20 pi (x_j - 0.5)))) and
        0-based i: f_i = 0.5 (1 + g) * x_0 * ... * x_{M-i-2}, multiplied by
        (1 - x_{M-i-1}) when i > 0. When g = 0 the objectives sum to 0.5.
    """
    X = np.atleast_2d(x)
    n_rows, n_vars = X.shape
    k = n_vars - M + 1
    centred = X[:, M - 1:] - 0.5
    g = 100 * (k + np.sum(centred ** 2 - np.cos(20 * np.pi * centred), axis=1))

    F = np.zeros((n_rows, M))
    # Each item of F.T is a view of one column of F, so in-place ops fill F.
    for i, column in enumerate(F.T):
        column[:] = 0.5 * (1 + g)
        for j in range(M - 1 - i):
            column *= X[:, j]
        if i:
            column *= 1 - X[:, M - 1 - i]
    return F

def dtlz2(x, M=2):
    """DTLZ2: spherical front.

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D); a 1-D input is treated as one row. Columns
        M-1 onward are the distance variables.
    M : int
        Number of objectives.

    Returns
    -------
    ndarray of shape (N, M)
        With g = sum((x_j - 0.5)^2) and 0-based i:
        f_i = (1 + g) * prod_{j < M-i-1} cos(x_j pi/2), times sin(x_{M-i-1} pi/2)
        when i > 0.
    """
    X = np.atleast_2d(x)
    n_rows, _ = X.shape
    g = np.sum((X[:, M - 1:] - 0.5) ** 2, axis=1)

    F = np.zeros((n_rows, M))
    # Iterating F.T yields column views of F; in-place updates write into F.
    for i, column in enumerate(F.T):
        column[:] = 1 + g
        for j in range(M - 1 - i):
            column *= np.cos(X[:, j] * np.pi / 2)
        if i:
            column *= np.sin(X[:, M - 1 - i] * np.pi / 2)
    return F

def dtlz3(x, M=2):
    """DTLZ3: spherical front with many local optima (multimodal g).

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D); a 1-D input is treated as one row. Columns
        M-1 onward are the k = D - M + 1 distance variables.
    M : int
        Number of objectives.

    Returns
    -------
    ndarray of shape (N, M)
        Same spherical products as DTLZ2, scaled by (1 + g) with
        g = 100 (k + sum((x_j - 0.5)^2 - cos(20 pi (x_j - 0.5)))).
    """
    X = np.atleast_2d(x)
    n_rows, n_vars = X.shape
    k = n_vars - M + 1
    centred = X[:, M - 1:] - 0.5
    g = 100 * (k + np.sum(centred ** 2 - np.cos(20 * np.pi * centred), axis=1))

    F = np.zeros((n_rows, M))
    # Iterating F.T yields column views of F; in-place updates write into F.
    for i, column in enumerate(F.T):
        column[:] = 1 + g
        for j in range(M - 1 - i):
            column *= np.cos(X[:, j] * np.pi / 2)
        if i:
            column *= np.sin(X[:, M - 1 - i] * np.pi / 2)
    return F

def dtlz4(x, M=2, alpha=100):
    """DTLZ4: spherical front with a biased (density-skewed) mapping.

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D); a 1-D input is treated as one row. Columns
        M-1 onward are the distance variables.
    M : int
        Number of objectives.
    alpha : float
        Bias exponent: position variables enter as x_j**alpha (alpha = 1
        reproduces DTLZ2).

    Returns
    -------
    ndarray of shape (N, M)
    """
    X = np.atleast_2d(x)
    n_rows, _ = X.shape
    g = np.sum((X[:, M - 1:] - 0.5) ** 2, axis=1)

    F = np.zeros((n_rows, M))
    # Iterating F.T yields column views of F; in-place updates write into F.
    for i, column in enumerate(F.T):
        column[:] = 1 + g
        for j in range(M - 1 - i):
            column *= np.cos(X[:, j] ** alpha * np.pi / 2)
        if i:
            column *= np.sin(X[:, M - 1 - i] ** alpha * np.pi / 2)
    return F

def dtlz5(x, M=2):
    """DTLZ5: degenerate front (a curve for every M).

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D) in the unit cube; a 1-D input is one row.
        Columns M-1 onward are the distance variables.
    M : int
        Number of objectives (>= 2).

    Returns
    -------
    ndarray of shape (N, M)

    Notes
    -----
    With g = sum((x_j - 0.5)^2) over the distance variables, the angles are
    theta_0 = x_0 pi / 2 and theta_i = pi (1 + 2 g x_i) / (4 (1 + g)) for
    1 <= i <= M - 2 (0-based). Objective i is
    (1 + g) * prod_{j < M-i-1} cos(theta_j), times sin(theta_{M-i-1}) if i > 0.
    """
    x = np.atleast_2d(x)
    N, _ = x.shape

    g = np.sum((x[:, M-1:] - 0.5)**2, axis=1)

    # Compute theta. All remaining angles are computed in one vectorized step.
    # g is a column so it broadcasts per row, and the right-hand side is
    # (N, M-2), which matches the target slice. An (N, 1) column cannot be
    # assigned into the 1-D slice theta[:, i] when N > 1.
    theta = np.zeros((N, M-1))
    theta[:, 0] = x[:, 0] * np.pi / 2
    g_col = g[:, np.newaxis]
    theta[:, 1:] = (1 + 2 * g_col * x[:, 1:M-1]) * np.pi / (4 * (1 + g_col))

    f = np.zeros((N, M))
    for i in range(M):
        f[:, i] = (1 + g)
        for j in range(M - i - 1):
            f[:, i] *= np.cos(theta[:, j])
        if i > 0:
            f[:, i] *= np.sin(theta[:, M - i - 1])

    return f

def dtlz6(x, M=2):
    """DTLZ6: degenerate front like DTLZ5, with a harder g = sum(x_j**0.1).

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D) in the unit cube; a 1-D input is one row.
        Columns M-1 onward are the distance variables.
    M : int
        Number of objectives (>= 2).

    Returns
    -------
    ndarray of shape (N, M)

    Notes
    -----
    Angles: theta_0 = x_0 pi / 2 and theta_i = pi (1 + 2 g x_i) / (4 (1 + g))
    for 1 <= i <= M - 2 (0-based). The objectives are the same spherical
    products as DTLZ5.
    """
    x = np.atleast_2d(x)
    N, _ = x.shape

    g = np.sum(x[:, M-1:]**0.1, axis=1)

    # Compute theta. All remaining angles are computed in one vectorized step.
    # The right-hand side is (N, M-2), which matches the target slice. An
    # (N, 1) column cannot be assigned into the 1-D slice theta[:, i] when N > 1.
    theta = np.zeros((N, M-1))
    theta[:, 0] = x[:, 0] * np.pi / 2
    g_col = g[:, np.newaxis]
    theta[:, 1:] = (1 + 2 * g_col * x[:, 1:M-1]) * np.pi / (4 * (1 + g_col))

    f = np.zeros((N, M))
    for i in range(M):
        f[:, i] = (1 + g)
        for j in range(M - i - 1):
            f[:, i] *= np.cos(theta[:, j])
        if i > 0:
            f[:, i] *= np.sin(theta[:, M - i - 1])

    return f

def dtlz7(x, M=2):
    """DTLZ7: disconnected front.

    Parameters
    ----------
    x : array_like
        Decision matrix (N, D); a 1-D input is treated as one row. The last
        k = D - M + 1 columns feed g.
    M : int
        Number of objectives.

    Returns
    -------
    ndarray of shape (N, M)
        f_i = x_i for i < M - 1 (0-based); g = 1 + 9 * sum(x_{M-1..}) / k;
        f_{M-1} = (1 + g) * (M - sum_i f_i / (1 + g) * (1 + sin(3 pi f_i))).
    """
    X = np.atleast_2d(x)
    n_rows, n_vars = X.shape
    k = n_vars - M + 1
    g = 1 + 9 * np.sum(X[:, M - 1:], axis=1) / k

    F = np.zeros((n_rows, M))
    for col in range(M - 1):
        F[:, col] = X[:, col]

    head = F[:, :M - 1]
    penalty = head / (1 + g[:, np.newaxis]) * (1 + np.sin(3 * np.pi * head))
    F[:, M - 1] = (1 + g) * (M - np.sum(penalty, axis=1))
    return F

# ====================== True Pareto Fronts ======================

def true_pareto_front_zdt1(n_points=100):
    """Reference front for ZDT1: f2 = 1 - sqrt(f1) for f1 in [0, 1].

    Returns an (n_points, 2) array with evenly spaced f1 values.
    """
    grid = np.linspace(0, 1, n_points)
    return np.stack((grid, 1 - np.sqrt(grid)), axis=1)

def true_pareto_front_zdt2(n_points=100):
    """Reference front for ZDT2: f2 = 1 - f1**2 for f1 in [0, 1].

    Returns an (n_points, 2) array with evenly spaced f1 values.
    """
    grid = np.linspace(0, 1, n_points)
    return np.stack((grid, 1 - grid ** 2), axis=1)

def true_pareto_front_zdt3(n_points=100):
    """Reference front for ZDT3 (five disconnected segments).

    With g = 1 the front is f2 = 1 - sqrt(f1) - f1 * sin(10 pi f1), restricted
    to the five f1 intervals below (its non-dominated pieces).

    Parameters
    ----------
    n_points : int
        Total number of points. They are split across the segments as evenly
        as possible, and the first ``n_points % 5`` segments get one extra
        point. Exactly ``n_points`` rows are returned.

    Returns
    -------
    ndarray of shape (n_points, 2)
    """
    regions = [
        (0.0, 0.0830),
        (0.1822, 0.2577),
        (0.4093, 0.4538),
        (0.6183, 0.6525),
        (0.8233, 0.8518)
    ]

    base, extra = divmod(n_points, len(regions))
    front_points = []
    for idx, (start, end) in enumerate(regions):
        # Spread the remainder over the first segments instead of dropping it.
        count = base + (1 if idx < extra else 0)
        f1_region = np.linspace(start, end, count)
        f2_region = 1 - np.sqrt(f1_region) - f1_region * np.sin(10 * np.pi * f1_region)
        front_points.append(np.column_stack([f1_region, f2_region]))

    return np.vstack(front_points)

def true_pareto_front_zdt6(n_points=100):
    """Reference front for ZDT6: f2 = 1 - f1**2 for f1 in [0.2807753191, 1].

    Returns an (n_points, 2) array with evenly spaced f1 values.
    """
    lowest_f1 = 0.2807753191  # lower end of the f1 range covered by the ZDT6 front
    grid = np.linspace(lowest_f1, 1.0, n_points)
    return np.stack((grid, 1 - grid ** 2), axis=1)

def true_pareto_front_dtlz124(M=2, n_points=100):
    """Sample points for the DTLZ1-DTLZ4 reference fronts.

    * M == 2: ``n_points`` evenly spaced points on the quarter unit circle
      f1**2 + f2**2 = 1 (the DTLZ2-4 front).
    * otherwise: ``n_points`` draws from a flat Dirichlet distribution over the
      M-simplex. When M > 2 they are scaled by 0.5, so each row sums to 0.5
      (the DTLZ1 front).

    NOTE(review): each branch matches only one of the two front families.
    Callers needing the other family must project the result. The M != 2
    branch uses the global NumPy RNG, so results depend on its state.
    """
    if M != 2:
        simplex = np.random.dirichlet(np.ones(M), n_points)
        return simplex * 0.5 if M > 2 else simplex

    angles = np.linspace(0, np.pi / 2, n_points)
    return np.column_stack([np.cos(angles), np.sin(angles)])

def true_pareto_front_dtlz5(M=2, n_points=100):
    """Reference front for DTLZ5/DTLZ6 (a degenerate one-dimensional curve).

    On the front g = 0, so theta_1 = t * pi / 2 with t in [0, 1], and every
    other angle equals pi / 4. Substituting into the DTLZ5 objectives gives
    (1-based indices):

        f_1 = cos(theta_1) * (1/sqrt(2)) ** (M - 2)
        f_i = cos(theta_1) * (1/sqrt(2)) ** (M - i)    for 2 <= i <= M - 1
        f_M = sin(theta_1)

    For M = 2 this is the full quarter unit circle. Every point satisfies
    sum(f**2) == 1.

    Parameters
    ----------
    M : int
        Number of objectives (>= 2).
    n_points : int
        Number of points along the curve.

    Returns
    -------
    ndarray of shape (n_points, M)

    Raises
    ------
    ValueError
        If M < 2.
    """
    if M < 2:
        raise ValueError("DTLZ5/DTLZ6 fronts require M >= 2")

    theta = np.linspace(0, 1, n_points) * np.pi / 2
    cos_t = np.cos(theta)
    inv_sqrt2 = np.sqrt(0.5)

    front = np.empty((n_points, M))
    front[:, 0] = cos_t * inv_sqrt2 ** (M - 2)
    for i in range(1, M - 1):
        front[:, i] = cos_t * inv_sqrt2 ** (M - i - 1)
    front[:, M - 1] = np.sin(theta)
    return front

def _nondominated_mask(F):
    """Return a boolean mask of the rows of F that no other row dominates.

    Minimization is assumed: row a dominates row b when a <= b in every
    objective and a < b in at least one. This is an O(n^2 * M) pairwise
    comparison, intended for small reference sets.
    """
    F = np.asarray(F, dtype=float)
    # no_worse[a, b] & better[a, b] is True exactly when row a dominates row b.
    no_worse = np.all(F[:, None, :] <= F[None, :, :], axis=2)
    better = np.any(F[:, None, :] < F[None, :, :], axis=2)
    return ~np.any(no_worse & better, axis=0)

def true_pareto_front_dtlz7(M=2, n_points=100):
    """Reference front for DTLZ7 (disconnected).

    On the front g takes its minimum value 1 (all distance variables 0), so
    the last objective becomes

        f_M = (1 + g) * (M - sum_i f_i / (1 + g) * (1 + sin(3 pi f_i)))
            = 2 M - sum_i f_i * (1 + sin(3 pi f_i)).

    Only the non-dominated parts of that surface form the front, which is why
    it splits into disconnected regions.

    Parameters
    ----------
    M : int
        Number of objectives.
    n_points : int
        M == 2: number of points returned. The curve is traced densely and
        filtered, and then ``n_points`` points are picked evenly by index
        across the surviving samples.
        Otherwise: number of random samples drawn (from the global NumPy
        RNG) before filtering, so at most ``n_points`` rows are returned.

    Returns
    -------
    ndarray of shape (n, M)
        Mutually non-dominated points on the DTLZ7 front.
    """
    if M == 2:
        f1 = np.linspace(0, 1, max(20 * n_points, 2000))
        f2 = 2 * M - f1 * (1 + np.sin(3 * np.pi * f1))
        # f1 is increasing, so a sample is non-dominated exactly when its f2
        # is strictly below every f2 seen before it.
        best_so_far = np.minimum.accumulate(f2)
        keep = np.concatenate(([True], f2[1:] < best_so_far[:-1]))
        f1, f2 = f1[keep], f2[keep]
        idx = np.unique(np.linspace(0, len(f1) - 1, n_points).round().astype(int))
        return np.column_stack([f1[idx], f2[idx]])

    # M != 2: random samples of the g = 1 surface, then drop dominated ones.
    samples = np.random.rand(n_points, M)
    head = samples[:, :-1]
    samples[:, -1] = 2 * M - np.sum(head * (1 + np.sin(3 * np.pi * head)), axis=1)
    return samples[_nondominated_mask(samples)]

# ====================== Problem Utilities ======================

def get_problem_bounds(problem_name, D):
    """Return per-variable [lower, upper] bounds for a benchmark problem.

    ZDT4 gets x1 in [0, 1] and x2..xD in [-5, 5] as a float (D, 2) array.
    Every other problem gets [0, 1] for each variable, built from integer
    pairs (an int array, or an empty array when D == 0).
    """
    if problem_name != 'ZDT4':
        return np.array([[0, 1] for _ in range(D)])

    limits = np.tile([-5.0, 5.0], (D, 1))
    limits[0] = (0.0, 1.0)
    return limits

def scale_to_bounds(x, bounds):
    """Map unit-cube variables linearly onto per-variable [lower, upper] ranges.

    Parameters
    ----------
    x : array_like, shape (N, D) or (D,)
        Values in [0, 1]; a 1-D input is treated as a single row.
    bounds : array_like, shape (D', 2)
        Lower/upper bound per column. Only the first D' columns of ``x`` are
        scaled; any further columns are copied unchanged.

    Returns
    -------
    ndarray
        New floating-point array with the same 2-D shape as ``x``.
    """
    x = np.atleast_2d(x)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer/bool storage would truncate fractional scaled values.
        x = x.astype(float)
    x_scaled = x.copy()
    if len(bounds) == 0:
        return x_scaled

    bounds = np.asarray(bounds, dtype=float)
    n = len(bounds)
    lower, upper = bounds[:, 0], bounds[:, 1]
    x_scaled[:, :n] = lower + x[:, :n] * (upper - lower)
    return x_scaled

def evaluate_problem(problem_name, x, M=2):
    """Evaluate a benchmark problem on decision variables in the unit cube.

    Parameters
    ----------
    problem_name : str
        Name understood by ``get_problem_function``.
    x : array_like
        Decision vector(s) in [0, 1]^D; a 1-D input is treated as one row.
    M : int, optional
        Number of objectives for DTLZ problems (ignored by ZDT problems).

    Returns
    -------
    ndarray
        Objective values, one row per decision vector.

    Notes
    -----
    No bound scaling is applied here. Every function returned by
    ``get_problem_function`` takes unit-cube inputs, and ``zdt4`` already maps
    x2..xD onto [-5, 5] itself. Pre-scaling with ``get_problem_bounds`` would
    map those variables twice, onto [-55, 45].
    """
    x = np.atleast_2d(x)
    problem_func = get_problem_function(problem_name, M)
    return problem_func(x)

# ====================== Visualization Functions ======================

def plot_pareto_front(problem_name, M=2, n_points=100):
    """Plot the reference Pareto front of a benchmark problem.

    Fronts are drawn as point clouds (scatter) rather than lines. Some fronts
    are disconnected (e.g. ZDT3), and some are produced by random sampling
    (the DTLZ1-4 sampler for M > 2). Joining consecutive rows with a line
    would draw segments that do not lie on the front. Only 2- and
    3-objective fronts can be drawn; other dimensions print a notice.

    Parameters
    ----------
    problem_name : str
        Benchmark name passed to ``get_true_pareto_front``.
    M : int
        Number of objectives.
    n_points : int
        Requested number of reference points.
    """
    true_front = get_true_pareto_front(problem_name, M, n_points)
    n_obj = true_front.shape[1]

    if n_obj == 2:
        plt.figure(figsize=(8, 6))
        plt.scatter(true_front[:, 0], true_front[:, 1], s=10, c='k', label='True Pareto Front')
        plt.xlabel('Objective 1')
        plt.ylabel('Objective 2')
        plt.title(f'{problem_name} True Pareto Front')
        plt.legend()
        plt.grid(True, alpha=0.7)
        plt.show()
    elif n_obj == 3:
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(true_front[:, 0], true_front[:, 1], true_front[:, 2], s=10, c='k')
        ax.set_xlabel('Objective 1')
        ax.set_ylabel('Objective 2')
        ax.set_zlabel('Objective 3')
        ax.set_title(f'{problem_name} True Pareto Front')
        plt.show()
    else:
        print(f"Cannot plot a {n_obj}-objective front (only 2 or 3 objectives are supported)")

def analyze_problem_characteristics(problem_name, D=30, M=2, n_samples=1000):
    """Print summary statistics of a problem's objective space.

    Draws ``n_samples`` uniform points in [0, 1]^D (global NumPy RNG) and
    evaluates the problem. It then prints per-objective ranges, the share of
    non-dominated samples, the objective correlation (2-objective case) and
    the extent of the reference front.

    Parameters
    ----------
    problem_name : str
        Benchmark name.
    D : int
        Number of decision variables.
    M : int
        Requested number of objectives (used by DTLZ problems; ZDT problems
        always return 2, so the actual output width is used for the loops).
    n_samples : int
        Number of random samples.

    Notes
    -----
    The non-dominated count relies on the optional project module
    ``u_rankmoea``. If it cannot be imported, that step is skipped with a
    message instead of aborting the whole analysis.
    """
    print(f"=== Problem Analysis: {problem_name} ({D}D, {M}M) ===")

    # Generate random samples
    X = np.random.rand(n_samples, D)
    problem_func = get_problem_function(problem_name, M)
    Y = problem_func(X)
    n_obj = Y.shape[1]

    # Basic statistics
    print("Objective ranges:")
    for i in range(n_obj):
        print(f"  f{i+1}: [{Y[:, i].min():.4f}, {Y[:, i].max():.4f}]")

    # Pareto front analysis
    try:
        from u_rankmoea import nondominated_frontpoints
    except ImportError as exc:
        print(f"Non-dominated analysis skipped: {exc}")
    else:
        front = nondominated_frontpoints(Y)
        print(f"Non-dominated points: {len(front)}/{n_samples} ({len(front)/n_samples*100:.1f}%)")

    # Correlation analysis
    if n_obj == 2:
        correlation = np.corrcoef(Y[:, 0], Y[:, 1])[0, 1]
        print(f"Objective correlation: {correlation:.4f}")

    # True front comparison
    try:
        true_front = get_true_pareto_front(problem_name, M)
    except Exception as exc:
        print(f"True front generation failed: {exc}")
    else:
        if len(true_front) > 0:
            print(f"True front points generated: {len(true_front)}")
            if true_front.ndim == 2 and true_front.shape[1] == 2:
                print(f"True front f1 range: [{true_front[:, 0].min():.4f}, {true_front[:, 0].max():.4f}]")
                print(f"True front f2 range: [{true_front[:, 1].min():.4f}, {true_front[:, 1].max():.4f}]")

    print()

# ====================== Real-world Problems ======================

def engineering_design_problem(x):
    """Simple pressure-vessel-style design problem with two objectives.

    Parameters
    ----------
    x : array_like, shape (N, 4) or (4,)
        Unit-cube variables; a 1-D input is treated as one row. Columns 0-1
        (thicknesses Ts, Th) are mapped to [0.0625, 6.1875]. Columns 2 onward
        are mapped to [10, 200]; column 2 is used as R and column 3 as L.

    Returns
    -------
    ndarray of shape (N, 2)
        f1: cost expression
            0.6224 Ts R L + 1.7781 Th R^2 + 3.1661 Ts^2 L + 19.84 Ts^2 R.
        f2: 0.0193 * (pi R^2 L + 4/3 pi R^3).
        Both objectives are minimized.
    """
    x = np.atleast_2d(x)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer storage would truncate the fractional scaled values below.
        x = x.astype(float)

    # Scale variables to actual ranges
    x_scaled = x.copy()
    x_scaled[:, :2] = 0.0625 + x[:, :2] * (6.1875 - 0.0625)
    x_scaled[:, 2:] = 10 + x[:, 2:] * (200 - 10)

    R, L = x_scaled[:, 2], x_scaled[:, 3]
    Ts, Th = x_scaled[:, 0], x_scaled[:, 1]

    # Cost objective (minimize)
    f1 = 0.6224 * Ts * R * L + 1.7781 * Th * R**2 + 3.1661 * Ts**2 * L + 19.84 * Ts**2 * R

    # Weight objective (minimize)
    f2 = np.pi * R**2 * L * 0.0193 + (4/3) * np.pi * R**3 * 0.0193

    return np.column_stack([f1, f2])

def portfolio_optimization_problem(x, returns_data=None):
    """Bi-objective portfolio selection: maximize expected return, minimize risk.

    Parameters
    ----------
    x : array_like, shape (N, n_assets) or (n_assets,)
        Raw allocations (presumably non-negative). Each row is normalized to
        sum to 1; an all-zero row is treated as an equally weighted portfolio.
    returns_data : array_like, shape (T, n_assets), optional
        Historical returns. When omitted, 252 synthetic returns are drawn from
        a private ``RandomState(42)``, using a symmetric positive-semidefinite
        covariance. The global NumPy RNG is left untouched.

    Returns
    -------
    ndarray of shape (N, 2)
        Column 0: negative expected portfolio return (for minimization).
        Column 1: portfolio standard deviation sqrt(w^T C w).
    """
    x = np.atleast_2d(x).astype(float)
    n_assets = x.shape[1]

    # Normalize weights to sum to 1; zero-sum rows fall back to equal weights
    # instead of dividing by zero.
    totals = np.sum(x, axis=1, keepdims=True)
    weights = np.divide(x, totals, out=np.full(x.shape, 1.0 / max(n_assets, 1)),
                        where=totals != 0)

    if returns_data is None:
        # A private generator: reseeding the global RNG here would reset the
        # random stream of whatever optimizer is calling this objective.
        rng = np.random.RandomState(42)
        mean = rng.uniform(0.05, 0.15, n_assets)
        raw = rng.uniform(0.01, 0.05, (n_assets, n_assets))
        # raw @ raw.T is symmetric PSD, as multivariate_normal requires.
        cov = raw @ raw.T
        returns_data = rng.multivariate_normal(mean=mean, cov=cov, size=252)  # ~one year of daily returns

    returns_data = np.asarray(returns_data, dtype=float)

    # Calculate expected return and risk
    expected_returns = np.mean(returns_data, axis=0)
    # np.cov returns a 0-d array for a single asset; keep it 2-D.
    cov_matrix = np.atleast_2d(np.cov(returns_data.T))

    portfolio_returns = np.dot(weights, expected_returns)
    variances = np.sum(weights * np.dot(weights, cov_matrix), axis=1)
    # Clip tiny negative round-off before the square root.
    portfolio_risks = np.sqrt(np.maximum(variances, 0.0))

    # Objectives: maximize return (minimize negative), minimize risk
    f1 = -portfolio_returns
    f2 = portfolio_risks

    return np.column_stack([f1, f2])

if __name__ == "__main__":
    # Smoke test: evaluate every benchmark, build its reference front,
    # print a short analysis and draw two fronts.
    print("=== Benchmark Problems Test ===")

    # Test all problems
    problems = ['ZDT1', 'ZDT2', 'ZDT3', 'ZDT4', 'ZDT6',
                'DTLZ1', 'DTLZ2', 'DTLZ3', 'DTLZ4', 'DTLZ5', 'DTLZ6', 'DTLZ7']

    for problem_name in problems:
        try:
            # Test function evaluation
            X = np.random.rand(10, 30)
            problem_func = get_problem_function(problem_name, 2)
            Y = problem_func(X)

            # Test true front generation
            true_front = get_true_pareto_front(problem_name, 2)

            print(f"✓ {problem_name}: Y shape={Y.shape}, True front shape={true_front.shape}")

        except Exception as e:
            print(f"✗ {problem_name}: Error - {e}")

    # Analyze a few problems
    print("\n=== Problem Characteristics ===")
    for problem in ['ZDT1', 'DTLZ2']:
        try:
            analyze_problem_characteristics(problem, D=10, M=2, n_samples=500)
        except ImportError as e:
            # The analysis may import the optional project module u_rankmoea.
            print(f"Skipping analysis of {problem}: {e}")

    # Visualize some fronts
    print("=== Visualization Test ===")
    try:
        plot_pareto_front('ZDT1', M=2)
        plot_pareto_front('DTLZ2', M=2)
    except Exception as e:
        # e.g. no display / GUI backend available
        print(f"Visualization requires matplotlib display ({e})")

    print("Benchmark problems test completed!")